From 6e58aaae669ecfaef486a9adac02720fcdb9f93d Mon Sep 17 00:00:00 2001 From: curegit <37978051+curegit@users.noreply.github.com> Date: Tue, 13 Sep 2022 15:39:59 +0900 Subject: [PATCH] Fix ugly logic & code for case insensitive glob --- stylegan/dataset.py | 4 ++-- utilities/filesys.py | 39 +++++++++++++++++---------------------- utilities/iter.py | 10 ---------- 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/stylegan/dataset.py b/stylegan/dataset.py index 0c4a9a8..9a97fc3 100644 --- a/stylegan/dataset.py +++ b/stylegan/dataset.py @@ -3,7 +3,7 @@ from chainer.dataset import DatasetMixin from utilities.stdio import eprint from utilities.image import load_image -from utilities.filesys import glob_recursively +from utilities.filesys import relaxed_glob_recursively class ImageDataset(DatasetMixin): @@ -21,7 +21,7 @@ def __init__(self, directory, resolution): eprint(f"Invalid dataset: {directory}") eprint("Specified path is not a correct directory!") raise RuntimeError("Input error") - self.image_files = sum([glob_recursively(directory, e, robust_letter_case=True) for e in ImageDataset.extensions], []) + self.image_files = sum([relaxed_glob_recursively(directory, e) for e in ImageDataset.extensions], []) if not self.image_files: eprint(f"Invalid dataset: {directory}") eprint("No images found in the directory!") diff --git a/utilities/filesys.py b/utilities/filesys.py index d7df95d..f452819 100644 --- a/utilities/filesys.py +++ b/utilities/filesys.py @@ -1,7 +1,6 @@ import os import os.path import glob -from utilities.iter import dict_groupby def mkdirs(dirpath): os.makedirs(os.path.normpath(dirpath), exist_ok=True) @@ -17,25 +16,21 @@ def build_filepath(dirpath, filename, fileext, exist_ok=True, suffix="+"): filepath = os.path.normpath(os.path.join(dirpath, filename) + os.extsep + fileext) return filepath if exist_ok else alt_filepath(filepath, suffix) -def glob_recursively(dirpath, fileext, robust_letter_case=False): +def glob_recursively(dirpath, fileext): pattern = build_filepath(glob.escape(dirpath), os.path.join("**", "*"), glob.escape(fileext)) - ls = [f for f in glob.glob(pattern, recursive=True) if os.path.isfile(f)] - if robust_letter_case: - ls_dict = dict_groupby(ls, lambda f: os.path.basename(f).lower()) - exts = {fileext.lower(), fileext.capitalize(), fileext.upper()} - {fileext} - for e in exts: - pattern = build_filepath(glob.escape(dirpath), os.path.join("**", "*"), glob.escape(e)) - new = [] - for f in glob.glob(pattern, recursive=True): - if os.path.isfile(f): - name = os.path.basename(f).lower() - if name in ls_dict: - for l in ls_dict[name]: - if os.path.samefile(f, l): - break - else: - new.append(f) - else: - new.append(f) - ls += new - return ls + return [f for f in glob.glob(pattern, recursive=True) if os.path.isfile(f)] + +def relaxed_glob_recursively(dirpath, fileext): + lower, upper = fileext.lower(), fileext.upper() + ls = glob_recursively(dirpath, lower) + if lower == upper: + return ls + ls_upper = glob_recursively(dirpath, upper) + case_insensitive = len(ls) == len(ls_upper) > 0 and any(os.path.samefile(f, ls_upper[0]) for f in ls) + if case_insensitive: + return ls + ls += ls_upper + cap = fileext.capitalize() + if cap == lower or cap == upper: + return ls + return ls + glob_recursively(dirpath, cap) diff --git a/utilities/iter.py b/utilities/iter.py index 4a16a37..edddb74 100644 --- a/utilities/iter.py +++ b/utilities/iter.py @@ -8,13 +8,3 @@ def iter_batch(iterable, batch=1): iterator = iter(iterable) for i in iterator: yield itertools.chain([i], itertools.islice(iterator, batch - 1)) - -def dict_groupby(iterable, key): - d = dict() - for i in iterable: - k = key(i) - if k in d: - d[k].append(i) - else: - d[k] = [i] - return d