Skip to content

Commit

Permalink
Show CNN architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
curegit committed Sep 14, 2022
1 parent 6e58aaa commit 9a9dd02
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 3 deletions.
14 changes: 14 additions & 0 deletions interface/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ def print_parameter_counts(generator, discriminator=None):
print(f"- G: {generator.count_params()}")
print(f"- D: {discriminator.count_params()}")

def print_cnn_architecture(generator, discriminator=None):
if discriminator is None:
print("CNN channels:")
else:
print("Generator CNN channels:")
pad = max(max(len(str(s)) for s in s.channels) for _, s in generator.synthesizer.blocks)
for i, s in generator.synthesizer.blocks:
print(f"- Level {i}: " + " -> conv -> ".join(str(c).rjust(pad) for c in s.channels))
if discriminator is not None:
print("Discriminator CNN channels:")
pad = max(max(len(str(b)) for b in b.channels) for _, b in discriminator.blocks)
for i, b in discriminator.blocks:
print(f"- Level {i}: " + " -> conv -> ".join(str(b).rjust(pad) for b in b.channels))

def print_training_args(args):
if args.accum is None:
print(f"Batch size: {args.batch} (Group size: {'entire batch' if args.group == 0 else args.group})")
Expand Down
3 changes: 2 additions & 1 deletion show.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from stylegan.networks import Generator
from interface.args import CustomArgumentParser
from interface.stdout import print_model_args, print_parameter_counts, print_data_classes
from interface.stdout import print_model_args, print_parameter_counts, print_cnn_architecture, print_data_classes
from utilities.stdio import eprint

def main(args):
print("Loading a model...")
generator = Generator.load(args.generator)
print_model_args(generator)
print_parameter_counts(generator)
print_cnn_architecture(generator)
print_data_classes(generator)

def parse_args():
Expand Down
2 changes: 2 additions & 0 deletions stylegan/layers/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ class EqualizedConvolution2D(Link):

def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=0, nobias=False, initial_bias=Zero(), gain=sqrt(2)):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.stride = stride
self.pad = pad
self.c = gain / sqrt(in_channels * ksize ** 2)
Expand Down
12 changes: 12 additions & 0 deletions stylegan/layers/discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def __call__(self, x):
skip = self.down(self.conv3(x))
return (h + skip) / root(2)

@property
def channels(self):
yield self.conv1.in_channels
yield self.conv1.out_channels
yield self.conv2.out_channels

class OutputBlock(Chain):

def __init__(self, in_channels, conditional=False, group_size=None):
Expand All @@ -86,3 +92,9 @@ def __call__(self, x):
h1 = self.act1(self.conv1(self.mbstd(x)))
h2 = self.act2(self.conv2(h1))
return self.linear(h2)

@property
def channels(self):
yield self.conv1.in_channels
yield self.conv1.out_channels
yield self.conv2.out_channels
13 changes: 13 additions & 0 deletions stylegan/layers/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class WeightModulatedConvolution(Link):

def __init__(self, in_channels, out_channels, pointwise=False, demod=True):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.demod = demod
self.ksize = 1 if pointwise else 3
self.pad = 0 if pointwise else 1
Expand Down Expand Up @@ -124,6 +126,11 @@ def __call__(self, w, noise=1.0, freeze=None):
h4 = self.act(h3)
return h4, self.torgb(h4, self.style2(w))

@property
def channels(self):
yield self.wmconv.in_channels
yield self.wmconv.out_channels

class SkipArchitecture(Chain):

def __init__(self, size, in_channels, out_channels, level=2):
Expand Down Expand Up @@ -151,3 +158,9 @@ def __call__(self, x, y, w, noise=1.0, freeze=None):
h6 = self.noise2(h5, coefficient=noise, freeze=freeze)
h7 = self.act2(h6)
return h7, self.skip(y) + self.torgb(h7, self.style3(w))

@property
def channels(self):
yield self.wmconv1.in_channels
yield self.wmconv1.out_channels
yield self.wmconv2.out_channels
12 changes: 12 additions & 0 deletions stylegan/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def __call__(self, ws, noise=1.0, freeze=None):
h, rgb = s(h, rgb, w, noise=noise, freeze=freeze)
return rgb

@property
def blocks(self):
yield 1, self.init
for i, s in enumerate(self.skips, 2):
yield i, s

class Generator(Chain):

def __init__(self, size=512, depth=8, levels=7, first_channels=512, last_channels=64, categories=1):
Expand Down Expand Up @@ -182,3 +188,9 @@ def __call__(self, x, c=None):
h = self.main(x)
batch, channels = h.shape
return h.reshape(batch) if c is None else sum(h * c1, axis=1) / root(channels)

@property
def blocks(self):
for i, s in enumerate(self.main):
if i > 0:
yield i, s
3 changes: 2 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from stylegan.augmentation import AugmentationPipeline
from interface.args import dump_json, CustomArgumentParser
from interface.argtypes import uint, natural, ufloat, positive, rate
from interface.stdout import chainer_like_tqdm, print_model_args, print_parameter_counts, print_data_classes, print_training_args
from interface.stdout import chainer_like_tqdm, print_model_args, print_parameter_counts, print_cnn_architecture, print_data_classes, print_training_args
from utilities.stdio import eprint
from utilities.filesys import mkdirs, build_filepath

Expand All @@ -25,6 +25,7 @@ def main(args):
averaged_generator.to_device(args.device)
print_model_args(generator)
print_parameter_counts(generator, discriminator)
print_cnn_architecture(generator, discriminator)
optimizers = AdamSet(args.alpha, args.betas[0], args.betas[1], categories > 1)
optimizers.setup(generator, discriminator)
print("Preparing a dataset...")
Expand Down
3 changes: 2 additions & 1 deletion visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from stylegan.networks import Generator, Discriminator
from interface.args import CustomArgumentParser
from interface.argtypes import natural
from interface.stdout import print_model_args, print_parameter_counts
from interface.stdout import print_model_args, print_parameter_counts, print_cnn_architecture
from utilities.stdio import eprint
from utilities.filesys import mkdirs, build_filepath

Expand All @@ -24,6 +24,7 @@ def main(args):
discriminator.to_device(args.device)
print_model_args(generator)
print_parameter_counts(generator, discriminator)
print_cnn_architecture(generator, discriminator)
z = generator.generate_latents(args.batch)
c = generator.generate_conditions(args.batch) if args.categories > 1 else None
_, x = generator(z, c)
Expand Down

0 comments on commit 9a9dd02

Please sign in to comment.