From 7f87fcf16c616ba5b8063d8510a9b3e4648dc0f6 Mon Sep 17 00:00:00 2001 From: Andrew Ilyas Date: Tue, 4 Aug 2020 14:04:25 -0400 Subject: [PATCH] args for custom datasets --- robustness/datasets.py | 24 +++++++++++++++++------- robustness/loaders.py | 12 +++++++----- 2 files changed, 24 insertions(+), 12 deletions(-) diff --git a/robustness/datasets.py b/robustness/datasets.py index e67bebe..48860bc 100644 --- a/robustness/datasets.py +++ b/robustness/datasets.py @@ -72,21 +72,30 @@ def __init__(self, ds_name, data_path, **kwargs): transforms to apply to the validation images from the dataset """ - required_args = ['num_classes', 'mean', 'std', 'custom_class', - 'label_mapping', 'transform_train', 'transform_test'] - assert set(kwargs.keys()) == set(required_args), "Missing required args, only saw %s" % kwargs.keys() + required_args = ['num_classes', 'mean', 'std', + 'transform_train', 'transform_test'] + optional_args = ['custom_class', 'label_mapping', 'custom_class_args'] + + missing_args = set(required_args) - set(kwargs.keys()) + if len(missing_args) > 0: + raise ValueError("Missing required args %s" % missing_args) + + extra_args = set(kwargs.keys()) - set(required_args + optional_args) + if len(extra_args) > 0: + raise ValueError("Got unrecognized args %s" % extra_args) + final_kwargs = {k: kwargs.get(k, None) for k in required_args + optional_args} + self.ds_name = ds_name self.data_path = data_path - self.__dict__.update(kwargs) + self.__dict__.update(final_kwargs) def override_args(self, default_args, new_args): ''' Convenience method for overriding arguments. (Internal) ''' kwargs = {k: v for (k, v) in new_args.items() if v is not None} - extra_args = set(kwargs.keys()) - set(default_args.keys()) - if len(extra_args) > 0: raise ValueError(f"Invalid arguments: {extra_args}") for k in kwargs: + if not (k in default_args): continue req_type = type(default_args[k]) no_nones = (default_args[k] is not None) and (kwargs[k] is not None) if no_nones and (not isinstance(kwargs[k], req_type)): @@ -164,7 +173,8 @@ def make_loaders(self, workers, batch_size, data_aug=True, subset=None, only_val=only_val, seed=subset_seed, shuffle_train=shuffle_train, - shuffle_val=shuffle_val) + shuffle_val=shuffle_val, + custom_class_args=self.custom_class_args) class ImageNet(DataSet): ''' diff --git a/robustness/loaders.py b/robustness/loaders.py index d9a39de..4c1cfa2 100644 --- a/robustness/loaders.py +++ b/robustness/loaders.py @@ -23,7 +23,8 @@ def make_loaders(workers, batch_size, transforms, data_path, data_aug=True, custom_class=None, dataset="", label_mapping=None, subset=None, subset_type='rand', subset_start=0, val_batch_size=None, - only_val=False, shuffle_train=True, shuffle_val=True, seed=1): + only_val=False, shuffle_train=True, shuffle_val=True, seed=1, + custom_class_args=None): ''' **INTERNAL FUNCTION** @@ -58,11 +59,12 @@ def make_loaders(workers, batch_size, transforms, data_path, data_aug=True, test_set = folder.ImageFolder(root=test_path, transform=transform_test, label_mapping=label_mapping) else: + if custom_class_args is None: custom_class_args = {} if not only_val: - train_set = custom_class(root=data_path, train=True, - download=True, transform=transform_train) - test_set = custom_class(root=data_path, train=False, - download=True, transform=transform_test) + train_set = custom_class(root=data_path, train=True, download=True, + transform=transform_train, **custom_class_args) + test_set = custom_class(root=data_path, train=False, download=True, + transform=transform_test, **custom_class_args) if not only_val: attrs = ["samples", "train_data", "data"]