# python3.7
"""Edits latent codes with respect to a given latent classifier.

Basically, this script takes latent codes and a latent classifier as inputs,
and then shows how the synthesized images change as the latent codes are moved
along the gradient of the given classifier.

NOTE: If you want to use the W or W+ space of StyleGAN, please do not randomly
sample the latent codes, since neither the W nor the W+ space follows a
Gaussian distribution. Instead, please use `generate_data.py` to get latent
vectors from the W or W+ space first, and then pass them in with the
`--input_latent_codes_path` option.
"""
import os.path
import argparse

import cv2
import numpy as np
from tqdm import tqdm
import torch

from models.model_settings import MODEL_POOL
from models.pggan_generator import PGGANGenerator
from models.stylegan_generator import StyleGANGenerator
from models.stylegan2_generator import StyleGAN2Generator
from utils.logger import setup_logger
from utils.nl_manipulator import nonlinear_interpolate


def parse_args():
  """Parses arguments."""
  parser = argparse.ArgumentParser(
      description='Edit image synthesis with given semantic boundary.')
  parser.add_argument('-m', '--model_name', type=str, required=True,
                      choices=list(MODEL_POOL),
                      help='Name of the model for generation. (required)')
  parser.add_argument('-o', '--output_dir', type=str, required=True,
                      help='Directory to save the output results. (required)')
  parser.add_argument('-l', '--latent_classifier', type=str, required=True,
                      help='Path to the pre-trained latent classifier (.pkl '
                           'file) used to produce labels and gradients. '
                           '(required)')
  parser.add_argument('-i', '--input_latent_codes_path', type=str, default='',
                      help='If specified, will load latent codes from given '
                           'path instead of randomly sampling. (optional)')
  parser.add_argument('-n', '--num', type=int, default=1,
                      help='Number of images for editing. This field will be '
                           'ignored if `input_latent_codes_path` is specified. '
                           '(default: 1)')
  parser.add_argument('-s', '--latent_space_type', type=str, default='z',
                      choices=['z', 'Z', 'w', 'W', 'wp', 'wP', 'Wp', 'WP'],
                      help='Latent space used in StyleGAN. (default: `z`)')
  parser.add_argument('--end_distance', type=float, default=3.0,
                      help='End point for manipulation in latent space. '
                           '(default: 3.0)')
  parser.add_argument('--steps', type=int, default=10,
                      help='Number of steps for image editing. (default: 10)')
  return parser.parse_args()
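

# The assumed contract below is inferred from the call sites in `main()`, not
# from separate documentation: `torch.load` is expected to return the latent
# classifier as a ready-to-use model object, and `nonlinear_interpolate` is
# expected to take a single latent code of shape [1, latent_dim], that
# classifier, an `end_distance`, and a number of `steps`, and to return `steps`
# latent codes progressively moved along the classifier's gradient.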


def main():
  """Main function."""
  args = parse_args()
  logger = setup_logger(args.output_dir, logger_name='generate_data')

  logger.info(f'Initializing generator.')
  gan_type = MODEL_POOL[args.model_name]['gan_type']
  if gan_type == 'pggan':
    model = PGGANGenerator(args.model_name, logger)
    kwargs = {}
  elif gan_type == 'stylegan':
    model = StyleGANGenerator(args.model_name, logger)
    kwargs = {'latent_space_type': args.latent_space_type}
  elif gan_type == 'stylegan2':
    model = StyleGAN2Generator(args.model_name, logger)
    kwargs = {'latent_space_type': args.latent_space_type}
  else:
    raise NotImplementedError(f'Not implemented GAN type `{gan_type}`!')

  logger.info(f'Preparing classifier.')
  if not os.path.isfile(args.latent_classifier):
    raise ValueError(f'Latent classifier pickle `{args.latent_classifier}` '
                     f'does not exist!')
  lclass = torch.load(args.latent_classifier)

  logger.info(f'Preparing latent codes.')
  if os.path.isfile(args.input_latent_codes_path):
    logger.info(f'  Load latent codes from `{args.input_latent_codes_path}`.')
    latent_codes = np.load(args.input_latent_codes_path)
    latent_codes = model.preprocess(latent_codes, **kwargs)
  else:
    logger.info(f'  Sample latent codes randomly.')
    latent_codes = model.easy_sample(args.num, **kwargs)
  np.save(os.path.join(args.output_dir, 'latent_codes.npy'), latent_codes)
  total_num = latent_codes.shape[0]

  logger.info(f'Editing {total_num} samples.')
  for sample_id in tqdm(range(total_num), leave=False):
    # Move this latent code along the classifier's gradient direction.
    interpolations = nonlinear_interpolate(latent_codes[sample_id:sample_id + 1],
                                           lclass,
                                           end_distance=args.end_distance,
                                           steps=args.steps)
    interpolation_id = 0
    for interpolations_batch in model.get_batch_inputs(interpolations):
      if gan_type == 'pggan':
        outputs = model.easy_synthesize(interpolations_batch)
      elif gan_type in ('stylegan', 'stylegan2'):
        outputs = model.easy_synthesize(interpolations_batch, **kwargs)
      for image in outputs['image']:
        save_path = os.path.join(args.output_dir,
                                 f'{sample_id:03d}_{interpolation_id:03d}.jpg')
        # Synthesized images are RGB; OpenCV expects BGR, hence the channel flip.
        cv2.imwrite(save_path, image[:, :, ::-1])
        interpolation_id += 1
    assert interpolation_id == args.steps
    logger.debug(f'  Finished sample {sample_id:3d}.')
  logger.info(f'Successfully edited {total_num} samples.')


if __name__ == '__main__':
  main()