From 07474586542393338ef036e33fc8a0a985144238 Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Sat, 31 Dec 2022 20:43:46 +0900 Subject: [PATCH 1/5] all --- train.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 5165f98..7e3e31c 100755 --- a/train.py +++ b/train.py @@ -153,7 +153,9 @@ def preprocess_args(args): if args.transfer is not None: try: snapshot, g, d = args.transfer - args.transfer = snapshot, uint(g), uint(d) + g = 1 if g.lower() == "all" else uint(g) + d = args.levels if d.lower() == "all" else uint(d) + args.transfer = snapshot, g, d except: eprint("Transfer levels must be non-negative integers!") raise @@ -167,7 +169,7 @@ def parse_args(): parser.add_argument("-l", "--labels", metavar="CLASS", nargs="*", help="embed data class labels into output generators (provide CLASS as many as dataset directories), dataset directory names are automatically used if no CLASS arguments are given") group = parser.add_argument_group("training arguments") group.add_argument("-s", "--snapshot", metavar="HDF5_FILE", help="load weights and parameters from a snapshot (for resuming)") - group.add_argument("-t", "--transfer", metavar=("HDF5_FILE", "G", "D"), nargs=3, help="import CNN weights from another snapshot (transfer learning), transfer generator/discriminator CNN blocks only above/below level G/D (inclusive)") + group.add_argument("-t", "--transfer", metavar=("HDF5_FILE", "{G|all}", "{D|all}"), nargs=3, help="import CNN weights from another snapshot (transfer learning), transfer generator/discriminator CNN blocks only above/below level G/D (inclusive), or specify 'all' to transfer all blocks") group.add_argument("-Z", "--freeze", metavar=("G", "D"), nargs=2, type=uint, help="disable updating generator/discriminator CNN blocks above/below level G/D (inclusive), likely used with --transfer") group.add_argument("-e", "--epoch", metavar="N", type=uint, default=1, help="training duration in epoch (note that elapsed training duration will not be serialized in snapshot)") group.add_argument("-b", "--batch", metavar="N", type=natural, default=16, help="batch size, affecting not only memory usage, but also training result") From 712be7fe3dd5ef67ed74eef10b1563380570f28c Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Sun, 22 Jan 2023 23:55:09 +0900 Subject: [PATCH 2/5] Change exit status SIGINT --- animate.py | 2 +- combine.py | 2 +- generate.py | 2 +- mix.py | 2 +- show.py | 2 +- train.py | 2 +- visualize.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/animate.py b/animate.py index a7b9346..49a96a6 100755 --- a/animate.py +++ b/animate.py @@ -129,4 +129,4 @@ def parse_args(): main(check_args(parse_args())) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) diff --git a/combine.py b/combine.py index 2479357..09c88a3 100755 --- a/combine.py +++ b/combine.py @@ -55,4 +55,4 @@ def parse_args(): main(check_args(preprocess_args(parse_args()))) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) diff --git a/generate.py b/generate.py index f2c786f..19e6838 100755 --- a/generate.py +++ b/generate.py @@ -61,4 +61,4 @@ def parse_args(): main(parse_args()) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) diff --git a/mix.py b/mix.py index 6cdcc2a..e77aba0 100755 --- a/mix.py +++ b/mix.py @@ -54,4 +54,4 @@ def parse_args(): main(parse_args()) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) diff --git a/show.py b/show.py index 7fd3083..2dd83b1 100755 --- a/show.py +++ b/show.py @@ -23,4 +23,4 @@ def parse_args(): main(parse_args()) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) diff --git a/train.py b/train.py index 7e3e31c..c4b1f28 100755 --- a/train.py +++ b/train.py @@ -208,4 +208,4 @@ def parse_args(): main(check_args(preprocess_args(parse_args()))) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) diff --git a/visualize.py b/visualize.py index aa81ba2..2be5f68 100755 --- a/visualize.py +++ b/visualize.py @@ -52,4 +52,4 @@ def parse_args(): main(parse_args()) except KeyboardInterrupt: eprint("KeyboardInterrupt") - exit(1) + exit(130) From 9a00bd0084ba44629974eeee9b9ca52fbae816f1 Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Sun, 29 Jan 2023 21:26:08 +0900 Subject: [PATCH 3/5] fix --- stylegan/networks.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/stylegan/networks.py b/stylegan/networks.py index 3651abb..2b4ca53 100644 --- a/stylegan/networks.py +++ b/stylegan/networks.py @@ -233,10 +233,11 @@ def freeze(self, levels=[]): b.conv2.disable_update() def transfer(self, source, levels=[]): - for (i, dest), (_, src) in zip(self.blocks, source.blocks): + for (i, dest), (_, src) in zip(reversed(list(self.blocks)), reversed(list(source.blocks))): if i in levels: if isinstance(dest, FromRGB): - dest.copyparams(src) + pass + #dest.copyparams(src) elif isinstance(dest, ResidualBlock): dest.copyparams(src) else: From 7b10746cca5336b90b6ecf2370f5b5a244b614bd Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Sun, 12 Feb 2023 23:24:40 +0900 Subject: [PATCH 4/5] Fix --- stylegan/networks.py | 34 ++++++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 8 deletions(-) diff --git a/stylegan/networks.py b/stylegan/networks.py index 2b4ca53..6dfe4d6 100644 --- a/stylegan/networks.py +++ b/stylegan/networks.py @@ -144,21 +144,29 @@ def freeze(self, levels=[]): if isinstance(b, InitialSkipArchitecture): b.wmconv.disable_update() b.torgb.disable_update() - else: + elif isinstance(b, SkipArchitecture): b.wmconv1.disable_update() b.wmconv2.disable_update() b.torgb.disable_update() def transfer(self, source, levels=[]): - for (i, dest), (_, src) in zip(self.synthesizer.blocks, source.synthesizer.blocks): + for (i, dest), (j, src) in zip(reversed(list(self.synthesizer.blocks)), reversed(list(source.synthesizer.blocks))): if i in levels: if isinstance(dest, InitialSkipArchitecture): + if not isinstance(src, InitialSkipArchitecture): + eprint("Network architecture doesn't match!") + raise RuntimeError("Model error") dest.wmconv.copyparams(src.wmconv) dest.torgb.copyparams(src.torgb) - else: + elif isinstance(dest, SkipArchitecture): + if not isinstance(src, SkipArchitecture): + eprint("Network architecture doesn't match!") + raise RuntimeError("Model error") dest.wmconv1.copyparams(src.wmconv1) dest.wmconv2.copyparams(src.wmconv2) dest.torgb.copyparams(src.torgb) + else: + raise RuntimeError() def save(self, filepath): with HDF5File(filepath, "w") as hdf5: @@ -228,21 +236,31 @@ def freeze(self, levels=[]): b.disable_update() elif isinstance(b, ResidualBlock): b.disable_update() - else: + elif isinstance(b, OutputBlock): b.conv1.disable_update() b.conv2.disable_update() def transfer(self, source, levels=[]): - for (i, dest), (_, src) in zip(reversed(list(self.blocks)), reversed(list(source.blocks))): + for (i, dest), (j, src) in zip(self.blocks, source.blocks): if i in levels: if isinstance(dest, FromRGB): - pass - #dest.copyparams(src) + if not isinstance(src, FromRGB): + eprint("Network architecture doesn't match!") + raise RuntimeError("Model error") + dest.copyparams(src) elif isinstance(dest, ResidualBlock): + if not isinstance(src, ResidualBlock): + eprint("Network architecture doesn't match!") + raise RuntimeError("Model error") dest.copyparams(src) - else: + elif isinstance(dest, OutputBlock): + if not isinstance(src, OutputBlock): + eprint("Network architecture doesn't match!") + raise RuntimeError("Model error") dest.conv1.copyparams(src.conv1) dest.conv2.copyparams(src.conv2) + else: + raise RuntimeError() @property def blocks(self): From 37cefa7236a7d3dd1d98805f84aa119dcfbe3624 Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Mon, 13 Feb 2023 00:03:14 +0900 Subject: [PATCH 5/5] Fix help --- train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/train.py b/train.py index 91203ae..c930b31 100755 --- a/train.py +++ b/train.py @@ -176,7 +176,7 @@ def parse_args(): parser.add_argument("-l", "--labels", metavar="CLASS", nargs="*", help="embed data class labels into output generators (provide CLASS as many as dataset directories), dataset directory names are automatically used if no CLASS arguments are given") group = parser.add_argument_group("training arguments") group.add_argument("-s", "--snapshot", metavar="HDF5_FILE", help="load weights and parameters from a snapshot (for resuming)") - group.add_argument("-t", "--transfer", metavar=("HDF5_FILE", "{G|all}", "{D|all}"), nargs=3, help="import CNN weights from another snapshot (transfer learning), transfer generator/discriminator CNN blocks only above/below level G/D (inclusive), or specify 'all' to transfer all blocks") + group.add_argument("-t", "--transfer", metavar=("HDF5_FILE", "{G|all}", "{D|all}"), nargs=3, help="import CNN weights from another snapshot (transfer learning), transfer generator/discriminator CNN blocks to levels only above/below level G/D (inclusive) from corresponding ones aligned from the top/bottom level, specify 'all' to transfer all blocks") group.add_argument("-Z", "--freeze", metavar=("G", "D"), nargs=2, type=uint, help="disable updating generator/discriminator CNN blocks above/below level G/D (inclusive), likely used with --transfer") group.add_argument("-e", "--epoch", metavar="N", type=uint, default=1, help="training duration in epoch (note that elapsed training duration will not be serialized in snapshot)") group.add_argument("-b", "--batch", metavar="N", type=natural, default=16, help="batch size, affecting not only memory usage, but also training result")