Skip to content

Commit

Permalink
Embed model parameters in hdf5
Browse files Browse the repository at this point in the history
  • Loading branch information
curegit committed May 13, 2021
1 parent 4994e78 commit dad4069
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 14 deletions.
7 changes: 1 addition & 6 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

# Init model
print("Initializing model")
generator = Generator(args.size, args.depth, args.levels, *args.channels)
generator = Generator.load(args.generator)

# Print information
#args.stage = args.maxstage if args.stage > args.maxstage else args.stage
Expand All @@ -56,11 +56,6 @@
#print(f"Alpha: {args.alpha}, Latent: {'Yes' if args.center is not None else 'No'}, Deviation: {args.sd}")
#print(f"Truncation Trick: {args.psi if args.psi is not None else 'No'}, Device: {'CPU' if args.device < 0 else f'GPU {args.device}'}")

# Load model
if args.generator is not None:
print("Loading generator")
generator.load_weights(args.generator)

# GPU setting
generator.to_device(args.device)

Expand Down
31 changes: 25 additions & 6 deletions stylegan/networks.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from math import sqrt as root
from random import randint
from h5py import File as HDF5File
from chainer import Chain, ChainList, Sequential
from chainer.functions import sqrt, mean
from chainer.serializers import load_hdf5, save_hdf5
from chainer.serializers import HDF5Serializer, HDF5Deserializer
from stylegan.links.common import GaussianDistribution, EqualizedLinear, LeakyRelu
from stylegan.links.generator import InitialSkipArchitecture, SkipArchitecture
from stylegan.links.discriminator import FromRGB, ResidualBlock, OutputBlock
Expand Down Expand Up @@ -43,7 +44,10 @@ class Generator(Chain):
def __init__(self, size=512, depth=8, levels=7, first_channels=512, last_channels=64):
super().__init__()
self.size = size
self.depth = depth
self.levels = levels
self.first_channels = first_channels
self.last_channels = last_channels
self.resolution = (2 * 2 ** levels, 2 * 2 ** levels)
with self.init_scope():
self.sampler = GaussianDistribution()
Expand Down Expand Up @@ -78,11 +82,26 @@ def generate_masks(self, batch):
def calculate_mean_w(self, n=50000):
return mean(self.mapper(self.generate_latents(n)), axis=0)

def load_weights(self, filepath):
load_hdf5(filepath, self)

def save_weights(self, filepath):
save_hdf5(filepath, self)
def save(self, filepath):
with HDF5File(filepath, "w") as hdf5:
hdf5.create_dataset("size", data=self.size)
hdf5.create_dataset("depth", data=self.depth)
hdf5.create_dataset("levels", data=self.levels)
hdf5.create_dataset("first_channels", data=self.first_channels)
hdf5.create_dataset("last_channels", data=self.last_channels)
HDF5Serializer(hdf5.create_group("weights")).save(self)

@staticmethod
def load(filepath):
with HDF5File as hdf5:
size = int(hdf5["size"][()])
depth = int(hdf5["depth"][()])
levels = int(hdf5["levels"][()])
first_channels = int(hdf5["size"][()])
last_channels = int(hdf5["size"][()])
generator = Generator(size, depth, levels, first_channels, last_channels)
HDF5Deserializer(hdf5["weights"]).load(generator)
return generator

class Discriminator(Chain):

Expand Down
2 changes: 1 addition & 1 deletion stylegan/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def save_snapshot(trainer):
@staticmethod
def save_generator(trainer):
filepath = build_filepath(trainer.states_out, f"generator-{trainer.iteration}", "hdf5", trainer.overwrite)
trainer.updater.averaged_generator.save_weights(filepath)
trainer.updater.averaged_generator.save(filepath)

@staticmethod
def save_images(trainer):
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main(args):
trainer.enable_reports(500)
trainer.enable_progress_bar(1)
trainer.run()
averaged_generator.save_weights(build_filepath(args.dest, "generator", "hdf5", args.force))
averaged_generator.save(build_filepath(args.dest, "generator", "hdf5", args.force))
updater.save_states(build_filepath(args.dest, "snapshot", "hdf5", args.force))

def parse_args():
Expand Down

0 comments on commit dad4069

Please sign in to comment.