Skip to content

Commit

Permalink
Fix ugly logic & code for case insensitive glob
Browse files Browse the repository at this point in the history
  • Loading branch information
curegit committed Sep 13, 2022
1 parent 5c47a83 commit 6e58aaa
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 34 deletions.
4 changes: 2 additions & 2 deletions stylegan/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from chainer.dataset import DatasetMixin
from utilities.stdio import eprint
from utilities.image import load_image
from utilities.filesys import glob_recursively
from utilities.filesys import relaxed_glob_recursively

class ImageDataset(DatasetMixin):

Expand All @@ -21,7 +21,7 @@ def __init__(self, directory, resolution):
eprint(f"Invalid dataset: {directory}")
eprint("Specified path is not a correct directory!")
raise RuntimeError("Input error")
self.image_files = sum([glob_recursively(directory, e, robust_letter_case=True) for e in ImageDataset.extensions], [])
self.image_files = sum([relaxed_glob_recursively(directory, e) for e in ImageDataset.extensions], [])
if not self.image_files:
eprint(f"Invalid dataset: {directory}")
eprint("No images found in the directory!")
Expand Down
39 changes: 17 additions & 22 deletions utilities/filesys.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
import os.path
import glob
from utilities.iter import dict_groupby

def mkdirs(dirpath):
os.makedirs(os.path.normpath(dirpath), exist_ok=True)
Expand All @@ -17,25 +16,21 @@ def build_filepath(dirpath, filename, fileext, exist_ok=True, suffix="+"):
filepath = os.path.normpath(os.path.join(dirpath, filename) + os.extsep + fileext)
return filepath if exist_ok else alt_filepath(filepath, suffix)

def glob_recursively(dirpath, fileext, robust_letter_case=False):
def glob_recursively(dirpath, fileext):
pattern = build_filepath(glob.escape(dirpath), os.path.join("**", "*"), glob.escape(fileext))
ls = [f for f in glob.glob(pattern, recursive=True) if os.path.isfile(f)]
if robust_letter_case:
ls_dict = dict_groupby(ls, lambda f: os.path.basename(f).lower())
exts = {fileext.lower(), fileext.capitalize(), fileext.upper()} - {fileext}
for e in exts:
pattern = build_filepath(glob.escape(dirpath), os.path.join("**", "*"), glob.escape(e))
new = []
for f in glob.glob(pattern, recursive=True):
if os.path.isfile(f):
name = os.path.basename(f).lower()
if name in ls_dict:
for l in ls_dict[name]:
if os.path.samefile(f, l):
break
else:
new.append(f)
else:
new.append(f)
ls += new
return ls
return [f for f in glob.glob(pattern, recursive=True) if os.path.isfile(f)]

def relaxed_glob_recursively(dirpath, fileext):
lower, upper = fileext.lower(), fileext.upper()
ls = glob_recursively(dirpath, lower)
if lower == upper:
return ls
ls_upper = glob_recursively(dirpath, upper)
case_insensitive = len(ls) == len(ls_upper) > 0 and any(os.path.samefile(f, ls_upper[0]) for f in ls)
if case_insensitive:
return ls
ls += ls_upper
cap = fileext.capitalize()
if cap == lower or cap == upper:
return ls
return ls + glob_recursively(dirpath, cap)
10 changes: 0 additions & 10 deletions utilities/iter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,3 @@ def iter_batch(iterable, batch=1):
iterator = iter(iterable)
for i in iterator:
yield itertools.chain([i], itertools.islice(iterator, batch - 1))

def dict_groupby(iterable, key):
d = dict()
for i in iterable:
k = key(i)
if k in d:
d[k].append(i)
else:
d[k] = [i]
return d

0 comments on commit 6e58aaa

Please sign in to comment.