Skip to content

Commit

Permalink
Merge branch 'transfer'
Browse files — browse the repository at this point in the history
  • Loading branch information
curegit committed Feb 12, 2023
2 parents 4bd2300 + 5d2b7d8 commit 597c509
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 8 deletions.
31 changes: 25 additions & 6 deletions stylegan/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,21 +144,29 @@ def freeze(self, levels=[]):
if isinstance(b, InitialSkipArchitecture):
b.wmconv.disable_update()
b.torgb.disable_update()
else:
elif isinstance(b, SkipArchitecture):
b.wmconv1.disable_update()
b.wmconv2.disable_update()
b.torgb.disable_update()

def transfer(self, source, levels=[]):
for (i, dest), (_, src) in zip(self.synthesizer.blocks, source.synthesizer.blocks):
for (i, dest), (j, src) in zip(reversed(list(self.synthesizer.blocks)), reversed(list(source.synthesizer.blocks))):
if i in levels:
if isinstance(dest, InitialSkipArchitecture):
if not isinstance(src, InitialSkipArchitecture):
eprint("Network architecture doesn't match!")
raise RuntimeError("Model error")
dest.wmconv.copyparams(src.wmconv)
dest.torgb.copyparams(src.torgb)
else:
elif isinstance(dest, SkipArchitecture):
if not isinstance(src, SkipArchitecture):
eprint("Network architecture doesn't match!")
raise RuntimeError("Model error")
dest.wmconv1.copyparams(src.wmconv1)
dest.wmconv2.copyparams(src.wmconv2)
dest.torgb.copyparams(src.torgb)
else:
raise RuntimeError()

def save(self, filepath):
with HDF5File(filepath, "w") as hdf5:
Expand Down Expand Up @@ -228,20 +236,31 @@ def freeze(self, levels=[]):
b.disable_update()
elif isinstance(b, ResidualBlock):
b.disable_update()
else:
elif isinstance(b, OutputBlock):
b.conv1.disable_update()
b.conv2.disable_update()

def transfer(self, source, levels=[]):
for (i, dest), (_, src) in zip(self.blocks, source.blocks):
for (i, dest), (j, src) in zip(self.blocks, source.blocks):
if i in levels:
if isinstance(dest, FromRGB):
if not isinstance(src, FromRGB):
eprint("Network architecture doesn't match!")
raise RuntimeError("Model error")
dest.copyparams(src)
elif isinstance(dest, ResidualBlock):
if not isinstance(src, ResidualBlock):
eprint("Network architecture doesn't match!")
raise RuntimeError("Model error")
dest.copyparams(src)
else:
elif isinstance(dest, OutputBlock):
if not isinstance(src, OutputBlock):
eprint("Network architecture doesn't match!")
raise RuntimeError("Model error")
dest.conv1.copyparams(src.conv1)
dest.conv2.copyparams(src.conv2)
else:
raise RuntimeError()

@property
def blocks(self):
Expand Down
6 changes: 4 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def preprocess_args(args):
if args.transfer is not None:
try:
snapshot, g, d = args.transfer
args.transfer = snapshot, uint(g), uint(d)
g = 1 if g.lower() == "all" else uint(g)
d = args.levels if d.lower() == "all" else uint(d)
args.transfer = snapshot, g, d
except:
eprint("Transfer levels must be non-negative integers!")
raise
Expand All @@ -174,7 +176,7 @@ def parse_args():
parser.add_argument("-l", "--labels", metavar="CLASS", nargs="*", help="embed data class labels into output generators (provide CLASS as many as dataset directories), dataset directory names are automatically used if no CLASS arguments are given")
group = parser.add_argument_group("training arguments")
group.add_argument("-s", "--snapshot", metavar="HDF5_FILE", help="load weights and parameters from a snapshot (for resuming)")
group.add_argument("-t", "--transfer", metavar=("HDF5_FILE", "G", "D"), nargs=3, help="import CNN weights from another snapshot (transfer learning), transfer generator/discriminator CNN blocks only above/below level G/D (inclusive)")
group.add_argument("-t", "--transfer", metavar=("HDF5_FILE", "{G|all}", "{D|all}"), nargs=3, help="import CNN weights from another snapshot (transfer learning), transfer generator/discriminator CNN blocks to levels only above/below level G/D (inclusive) from corresponding ones aligned from the top/bottom level, specify 'all' to transfer all blocks")
group.add_argument("-Z", "--freeze", metavar=("G", "D"), nargs=2, type=uint, help="disable updating generator/discriminator CNN blocks above/below level G/D (inclusive), likely used with --transfer")
group.add_argument("-e", "--epoch", metavar="N", type=uint, default=1, help="training duration in epoch (note that elapsed training duration will not be serialized in snapshot)")
group.add_argument("-b", "--batch", metavar="N", type=natural, default=16, help="batch size, affecting not only memory usage, but also training result")
Expand Down

0 comments on commit 597c509

Please sign in to comment.