Showing 20 changed files with 724 additions and 512 deletions.
@@ -1,126 +1,64 @@
#from json import dump
#from shutil import rmtree
#from argparse import ArgumentParser
import numpy as np
from tqdm import tqdm
from chainer import global_config
from chainer import global_config, Variable
from stylegan.networks import Generator
from interface.args import CustomArgumentParser
#from modules.argtypes import uint, natural, ufloat, positive, rate, filename, device
from interface.stdout import chainer_like_tqdm
from utilities.image import save_image
from utilities.stdio import eprint
from utilities.filesys import mkdirs, build_filepath
from utilities.iter import range_batch

# Parse command line arguments
parser = CustomArgumentParser("")
parser.add_output_args("images").add_model_args().add_evaluation_args().add_generation_args()
#parser.add_argument("-q", "--quit", action="store_true", help="")
#parser.add_argument("-f", "--force", action="store_true", help="allow overwrite existing files")
#parser.add_argument("-w", "--wipe", action="store_true", help="")
#parser.add_argument("-j", "--dump-json", action="store_true", help="")
#parser.add_argument("-i", "--image-only", action="store_true", help="")
'''
parser.add_argument("-r", "--result", "-d", "--directory", metavar="DEST", dest="directory", default="images", help="destination directory for generated images")
parser.add_argument("-p", "--prefix", type=filename, default="", help="filename prefix for generated images")
parser.add_argument("-g", "--generator", metavar="FILE", help="HDF5 file of serialized trained model to load")
parser.add_argument("-s", "--stage", type=int, choices=[1, 2, 3, 4, 5, 6, 7, 8, 9], default=7, help="growth stage, defining image resolution")
parser.add_argument("-x", "--max-stage", dest="maxstage", type=int, choices=[1, 2, 3, 4, 5, 6, 7, 8, 9], default=7, help="final stage")
parser.add_argument("-c", "--channels", metavar="CH", type=natural, nargs=2, default=(512, 16), help="numbers of channels at initial stage and final stage")
parser.add_argument("-z", "--z-size", dest="size", type=natural, default=512, help="latent vector (feature vector) size")
parser.add_argument("-m", "--mlp-depth", metavar="DEPTH", dest="depth", type=natural, default=8, help="MLP depth of mapping network")
parser.add_argument("-n", "--number", type=uint, default=1, help="the number of images to generate")
parser.add_argument("-b", "--batch", type=natural, default=1, help="batch size, affecting memory usage")
parser.add_argument("-a", "--alpha", type=rate, default=1.0, help="")
parser.add_argument("-l", "--latent", "--center", metavar="FILE", dest="center", help="")
parser.add_argument("-e", "--deviation", "--sd", metavar="SIGMA", dest="sd", type=positive, default=1.0, help="")
parser.add_argument("-t", "--truncation-trick", "--psi", metavar="PSI", dest="psi", type=ufloat, help="")
parser.add_argument("-v", "--device", "--gpu", metavar="ID", dest="device", type=device, default=-1, help="use specified GPU or CPU device")
'''
args = parser.parse_args()
def main(args):
    global_config.train = False
    global_config.autotune = True
    global_config.cudnn_deterministic = True
    print("Loading a model...")
    generator = Generator.load(args.generator)
    generator.to_device(args.device)
    if args.center is not None:
        print("Loading a latent vector...")
        center = Variable(np.load(args.center))
    else:
        center = None
    if args.classes is not None or args.labels is not None:
        if not generator.conditional:
            eprint("Unconditional model doesn't have image classes!")
            raise RuntimeError("Class error")
        categories = [] if args.classes is None else list(args.classes)
        categories += [] if args.labels is None else [generator.lookup_label(l) for l in args.labels]
        if any(c >= generator.categories for c in categories):
            eprint("Some class numbers are not in the valid range!")
            raise RuntimeError("Class error")
    else:
        categories = None
    if args.psi != 1.0:
        mean_w = generator.calculate_mean_w()
    else:
        mean_w = None
    mkdirs(args.dest)
    with chainer_like_tqdm(desc="generation", total=args.number) as bar:
        for i, n in range_batch(args.number, args.batch):
            z = generator.generate_latents(n, center=center, sd=args.sd)
            c = generator.generate_conditions(n, categories=categories) if generator.conditional else None
            ws, y = generator(z, c, psi=args.psi, mean_w=mean_w)
            z.to_cpu()
            y.to_cpu()
            ws[0].to_cpu()
            for j in range(n):
                filename = f"{i + j + 1}"
                np.save(build_filepath(args.dest, filename + "-latent", "npy", args.force), z.array[j])
                np.save(build_filepath(args.dest, filename + "-style", "npy", args.force), ws[0].array[j])
                save_image(y.array[j], build_filepath(args.dest, filename, "png", args.force))
                bar.update()

# Config chainer
global_config.train = False
global_config.autotune = True
global_config.cudnn_deterministic = True
def parse_args():
    parser = CustomArgumentParser("Generate images of a trained generator from random latent vectors")
    parser.require_generator().add_output_args("images").add_generation_args().add_evaluation_args()
    return parser.parse_args()

# Init model
print("Initializing model")
generator = Generator.load(args.generator)

# Print information
#args.stage = args.maxstage if args.stage > args.maxstage else args.stage
#h, w = generator.resolution(args.stage)
#print(f"Total Generation: {args.number}, Batch: {args.batch}")
#print(f"MLP: {args.size}x{args.depth}, Stage: {args.stage}/{args.maxstage} ({w}x{h})")
#print(f"Channel: {args.channels[0]} (initial) -> {args.channels[1]} (final)")
#print(f"Alpha: {args.alpha}, Latent: {'Yes' if args.center is not None else 'No'}, Deviation: {args.sd}")
#print(f"Truncation Trick: {args.psi if args.psi is not None else 'No'}, Device: {'CPU' if args.device < 0 else f'GPU {args.device}'}")

# GPU setting
generator.to_device(args.device)

# Load center latent
'''
if args.center is not None:
    print("Loading latent")
    center = generator.wrap_latent(load_array(args.center))
else:
    center = None
'''

# Init destination folder
#print("Initializing destination directory")
#if args.wipe:
#    rmtree(args.directory, ignore_errors=True)
mkdirs(args.dest)

'''
# Dump command-line options
if args.dump_json:
    path = filepath(args.directory, "args_quit" if args.quit else "args", "json")
    path = path if args.force else altfilepath(path)
    with open(path, mode="w", encoding="utf-8") as fp:
        dump(vars(args), fp, indent=2, sort_keys=True)
# Quit mode
if args.quit:
    print("Finished (Quit mode)")
    exit(0)
'''

bf = "{desc} [{bar}] {percentage:5.1f}%"
with tqdm(desc="generation", total=args.number, bar_format=bf, miniters=1, ascii=".#", ncols=70) as bar:
    for i, n in range_batch(args.number, args.batch):
        #mixing = mix > random()
        z = generator.generate_latents(n)
        c = generator.generate_conditions(n)
        #mix_z = generator.generate_latent(n) if mixing else None
        ws, y = generator(z, c)
        z.to_cpu()
        y.to_cpu()
        for j in range(n):
            filename = f"{i + j + 1}"
            np.save(build_filepath(args.dest, filename, "npy", args.force), z.array[j])
            save_image(y.array[j], build_filepath(args.dest, filename, "png", args.force))
            bar.update()
'''
# Generate images
c = 0
mean_w = None if args.psi is None else generator.calculate_mean_w()
while c < args.number:
    n = min(args.number - c, args.batch)
    z = generator.generate_latent(n, center=center, sd=args.sd)
    y = generator(z, args.stage, alpha=args.alpha, psi=args.psi, mean_w=mean_w)
    z.to_cpu()
    y.to_cpu()
    for i in range(n):
        path = filepath(args.directory, f"{args.prefix}{c + i + 1}", "png")
        path = path if args.force else altfilepath(path)
        save_image(y.array[i], path)
        print(f"{c + i + 1}/{args.number}: Saved as {path}")
        if not args.image_only:
            path = filepath(args.directory, f"{args.prefix}{c + i + 1}", "npy")
            path = path if args.force else altfilepath(path)
            save_array(z.array[i], path)
    c += n
'''
if __name__ == "__main__":
    try:
        main(parse_args())
    except KeyboardInterrupt:
        eprint("KeyboardInterrupt")
        exit(1)
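Note on the psi and mean_w arguments handled in main() above: they implement the StyleGAN truncation trick, pulling each style vector toward the average style to trade sample diversity for fidelity. The diff does not show how Generator.__call__ applies them internally; a minimal sketch of the usual formulation (hypothetical helper, not code from this repository):

def truncate_w(w, mean_w, psi):
    # Interpolate each style vector toward the mean style.
    # psi = 1.0 leaves w unchanged; psi = 0.0 collapses every sample onto mean_w.
    return mean_w + psi * (w - mean_w)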
@@ -0,0 +1,39 @@
from tqdm import tqdm

bar_format = "{desc} [{bar}] {percentage:5.1f}%"

def chainer_like_tqdm(desc, total):
    return tqdm(desc=desc, total=total, bar_format=bar_format, miniters=1, ascii=".#", ncols=70)

def print_model_args(generator):
    h, w = generator.resolution
    print(f"Multilayer perceptron: {generator.size}x{generator.depth}")
    print(f"CNN layers: {generator.levels} levels (output = {w}x{h})")
    print(f"CNN channels: {generator.first_channels} (initial) -> {generator.last_channels} (final)")

def print_data_classes(generator):
    print(f"Data classes: {generator.categories if generator.conditional else '1 (unconditional)'}")
    if generator.conditional:
        for i, l in enumerate(generator.labels):
            print(f"- class {i}: {l}")

def print_training_args(args):
    if args.accum is None:
        print(f"Batch size: {args.batch} (Group size: {'entire batch' if args.group == 0 else args.group})")
    else:
        print(f"Accum/batch size: {args.accum}/{args.batch} (Group size: {'entire accum/batch' if args.group == 0 else args.group})")
    print(f"Style-mixing rate: {args.mix * 100}%")
    if args.gamma > 0 and args.r1 > 1:
        print(f"R1 regularization: coefficient = {args.gamma} (every {args.r1} iterations)")
    elif args.gamma > 0:
        print(f"R1 regularization: coefficient = {args.gamma} (every iteration)")
    else:
        print("R1 regularization: disabled")
    if args.weight > 0 and args.pl > 1:
        print(f"Path length regularization: coefficient = {args.weight}, decay = {args.decay} (every {args.pl} iterations)")
    elif args.weight > 0:
        print(f"Path length regularization: coefficient = {args.weight}, decay = {args.decay} (every iteration)")
    else:
        print(f"Path length regularization: disabled")
    print(f"Objective: {'least squares loss' if args.lsgan else 'logistic loss'}")
    print(f"Adam: alpha = {args.alpha}, beta1 = {args.betas[0]}, beta2 = {args.betas[1]}")
@@ -0,0 +1,22 @@
from stylegan.networks import Generator
from interface.args import CustomArgumentParser
from interface.stdout import print_model_args, print_data_classes
from utilities.stdio import eprint

def main(args):
    print("Loading a model...")
    generator = Generator.load(args.generator)
    print_model_args(generator)
    print_data_classes(generator)

def parse_args():
    parser = CustomArgumentParser("Show model arguments and data classes of a serialized generator")
    parser.require_generator()
    return parser.parse_args()

if __name__ == "__main__":
    try:
        main(parse_args())
    except KeyboardInterrupt:
        eprint("KeyboardInterrupt")
        exit(1)
@@ -0,0 +1,50 @@
from math import sqrt, log
from chainer import Parameter, Link
from chainer.functions import gaussian, leaky_relu, linear, convolution_2d, broadcast_to
from chainer.initializers import Zero, Normal

class GaussianDistribution():

    def __init__(self, link, mean=0.0, sd=1.0):
        self.link = link
        self.mean = mean
        self.ln_var = log(sd ** 2)

    def __call__(self, *shape):
        mean = self.link.xp.array(self.mean, dtype=self.link.xp.float32)
        ln_var = self.link.xp.array(self.ln_var, dtype=self.link.xp.float32)
        return gaussian(broadcast_to(mean, shape), broadcast_to(ln_var, shape))

class LeakyRelu():

    def __init__(self, a=0.2):
        self.a = a

    def __call__(self, x):
        return leaky_relu(x, self.a)

class EqualizedLinear(Link):

    def __init__(self, in_size, out_size, nobias=False, initial_bias=Zero(), gain=sqrt(2)):
        super().__init__()
        self.c = gain / sqrt(in_size)
        with self.init_scope():
            self.w = Parameter(shape=(out_size, in_size), initializer=Normal(1.0))
            self.b = None if nobias else Parameter(shape=out_size, initializer=initial_bias)

    def __call__(self, x):
        return linear(self.c * x, self.w, self.b)

class EqualizedConvolution2D(Link):

    def __init__(self, in_channels, out_channels, ksize=3, stride=1, pad=0, nobias=False, initial_bias=Zero(), gain=sqrt(2)):
        super().__init__()
        self.stride = stride
        self.pad = pad
        self.c = gain / sqrt(in_channels * ksize ** 2)
        with self.init_scope():
            self.w = Parameter(shape=(out_channels, in_channels, ksize, ksize), initializer=Normal(1.0))
            self.b = None if nobias else Parameter(shape=out_channels, initializer=initial_bias)

    def __call__(self, x):
        return convolution_2d(self.c * x, self.w, self.b, self.stride, self.pad)
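EqualizedLinear and EqualizedConvolution2D store unit-variance weights and rescale activations at runtime by the He constant c = gain / sqrt(fan_in), i.e. the equalized learning rate trick from Progressive GAN / StyleGAN. A NumPy-only sketch of the same scaling (independent of the Chainer links above) that can be used to sanity-check the intended output statistics:

import numpy as np

rng = np.random.default_rng(0)
in_size, out_size, gain = 512, 256, float(np.sqrt(2))
w = rng.standard_normal((out_size, in_size)).astype(np.float32)  # stored with unit variance
c = gain / np.sqrt(in_size)                                      # He constant, as in EqualizedLinear

x = rng.standard_normal((16, in_size)).astype(np.float32)
y = (c * x) @ w.T   # same arithmetic as EqualizedLinear.__call__ (bias omitted)
print(y.std())      # roughly equal to gain, i.e. the layer behaves as if He-initialized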