From 9a9dd027dcb661f1312193ec00140774a78a5403 Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Wed, 14 Sep 2022 09:34:16 +0900 Subject: [PATCH] Show CNN architecture --- interface/stdout.py | 14 ++++++++++++++ show.py | 3 ++- stylegan/layers/basic.py | 2 ++ stylegan/layers/discriminator.py | 12 ++++++++++++ stylegan/layers/generator.py | 13 +++++++++++++ stylegan/networks.py | 12 ++++++++++++ train.py | 3 ++- visualize.py | 3 ++- 8 files changed, 59 insertions(+), 3 deletions(-) diff --git a/interface/stdout.py b/interface/stdout.py index af7dcdb..8d1b5da 100644 --- a/interface/stdout.py +++ b/interface/stdout.py @@ -24,6 +24,20 @@ def print_parameter_counts(generator, discriminator=None): print(f"- G: {generator.count_params()}") print(f"- D: {discriminator.count_params()}") +def print_cnn_architecture(generator, discriminator=None): + if discriminator is None: + print("CNN channels:") + else: + print("Generator CNN channels:") + pad = max(max(len(str(s)) for s in s.channels) for _, s in generator.synthesizer.blocks) + for i, s in generator.synthesizer.blocks: + print(f"- Level {i}: " + " -> conv -> ".join(str(c).rjust(pad) for c in s.channels)) + if discriminator is not None: + print("Discriminator CNN channels:") + pad = max(max(len(str(b)) for b in b.channels) for _, b in discriminator.blocks) + for i, b in discriminator.blocks: + print(f"- Level {i}: " + " -> conv -> ".join(str(b).rjust(pad) for b in b.channels)) + def print_training_args(args): if args.accum is None: print(f"Batch size: {args.batch} (Group size: {'entire batch' if args.group == 0 else args.group})") diff --git a/show.py b/show.py index 1b15d51..b4c8dfc 100644 --- a/show.py +++ b/show.py @@ -1,6 +1,6 @@ from stylegan.networks import Generator from interface.args import CustomArgumentParser -from interface.stdout import print_model_args, print_parameter_counts, print_data_classes +from interface.stdout import print_model_args, print_parameter_counts, print_cnn_architecture, print_data_classes from utilities.stdio import eprint def main(args): @@ -8,6 +8,7 @@ def main(args): generator = Generator.load(args.generator) print_model_args(generator) print_parameter_counts(generator) + print_cnn_architecture(generator) print_data_classes(generator) def parse_args(): diff --git a/stylegan/layers/basic.py b/stylegan/layers/basic.py index 56b667a..43ce0c7 100644 --- a/stylegan/layers/basic.py +++ b/stylegan/layers/basic.py @@ -46,6 +46,8 @@ class EqualizedConvolution2D(Link): def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=0, nobias=False, initial_bias=Zero(), gain=sqrt(2)): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.stride = stride self.pad = pad self.c = gain / sqrt(in_channels * ksize ** 2) diff --git a/stylegan/layers/discriminator.py b/stylegan/layers/discriminator.py index 3ddd3d5..e75589f 100644 --- a/stylegan/layers/discriminator.py +++ b/stylegan/layers/discriminator.py @@ -70,6 +70,12 @@ def __call__(self, x): skip = self.down(self.conv3(x)) return (h + skip) / root(2) + @property + def channels(self): + yield self.conv1.in_channels + yield self.conv1.out_channels + yield self.conv2.out_channels + class OutputBlock(Chain): def __init__(self, in_channels, conditional=False, group_size=None): @@ -86,3 +92,9 @@ def __call__(self, x): h1 = self.act1(self.conv1(self.mbstd(x))) h2 = self.act2(self.conv2(h1)) return self.linear(h2) + + @property + def channels(self): + yield self.conv1.in_channels + yield self.conv1.out_channels + yield self.conv2.out_channels diff --git a/stylegan/layers/generator.py b/stylegan/layers/generator.py index 6bcf698..8d3bf86 100644 --- a/stylegan/layers/generator.py +++ b/stylegan/layers/generator.py @@ -56,6 +56,8 @@ class WeightModulatedConvolution(Link): def __init__(self, in_channels, out_channels, pointwise=False, demod=True): super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels self.demod = demod self.ksize = 1 if pointwise else 3 self.pad = 0 if pointwise else 1 @@ -124,6 +126,11 @@ def __call__(self, w, noise=1.0, freeze=None): h4 = self.act(h3) return h4, self.torgb(h4, self.style2(w)) + @property + def channels(self): + yield self.wmconv.in_channels + yield self.wmconv.out_channels + class SkipArchitecture(Chain): def __init__(self, size, in_channels, out_channels, level=2): @@ -151,3 +158,9 @@ def __call__(self, x, y, w, noise=1.0, freeze=None): h6 = self.noise2(h5, coefficient=noise, freeze=freeze) h7 = self.act2(h6) return h7, self.skip(y) + self.torgb(h7, self.style3(w)) + + @property + def channels(self): + yield self.wmconv1.in_channels + yield self.wmconv1.out_channels + yield self.wmconv2.out_channels diff --git a/stylegan/networks.py b/stylegan/networks.py index 8201c38..2604f37 100644 --- a/stylegan/networks.py +++ b/stylegan/networks.py @@ -44,6 +44,12 @@ def __call__(self, ws, noise=1.0, freeze=None): h, rgb = s(h, rgb, w, noise=noise, freeze=freeze) return rgb + @property + def blocks(self): + yield 1, self.init + for i, s in enumerate(self.skips, 2): + yield i, s + class Generator(Chain): def __init__(self, size=512, depth=8, levels=7, first_channels=512, last_channels=64, categories=1): @@ -182,3 +188,9 @@ def __call__(self, x, c=None): h = self.main(x) batch, channels = h.shape return h.reshape(batch) if c is None else sum(h * c1, axis=1) / root(channels) + + @property + def blocks(self): + for i, s in enumerate(self.main): + if i > 0: + yield i, s diff --git a/train.py b/train.py index 3308c7d..ad10776 100644 --- a/train.py +++ b/train.py @@ -7,7 +7,7 @@ from stylegan.augmentation import AugmentationPipeline from interface.args import dump_json, CustomArgumentParser from interface.argtypes import uint, natural, ufloat, positive, rate -from interface.stdout import chainer_like_tqdm, print_model_args, print_parameter_counts, print_data_classes, print_training_args +from interface.stdout import chainer_like_tqdm, print_model_args, print_parameter_counts, print_cnn_architecture, print_data_classes, print_training_args from utilities.stdio import eprint from utilities.filesys import mkdirs, build_filepath @@ -25,6 +25,7 @@ def main(args): averaged_generator.to_device(args.device) print_model_args(generator) print_parameter_counts(generator, discriminator) + print_cnn_architecture(generator, discriminator) optimizers = AdamSet(args.alpha, args.betas[0], args.betas[1], categories > 1) optimizers.setup(generator, discriminator) print("Preparing a dataset...") diff --git a/visualize.py b/visualize.py index 2c77270..bededa3 100644 --- a/visualize.py +++ b/visualize.py @@ -4,7 +4,7 @@ from stylegan.networks import Generator, Discriminator from interface.args import CustomArgumentParser from interface.argtypes import natural -from interface.stdout import print_model_args, print_parameter_counts +from interface.stdout import print_model_args, print_parameter_counts, print_cnn_architecture from utilities.stdio import eprint from utilities.filesys import mkdirs, build_filepath @@ -24,6 +24,7 @@ def main(args): discriminator.to_device(args.device) print_model_args(generator) print_parameter_counts(generator, discriminator) + print_cnn_architecture(generator, discriminator) z = generator.generate_latents(args.batch) c = generator.generate_conditions(args.batch) if args.categories > 1 else None _, x = generator(z, c)