Skip to content

Commit

Permalink
Merge pull request #66 from MadryLab/custom_dataset_args
Browse files Browse the repository at this point in the history
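Add args for custom datasets: an optional `custom_class_args` dict that is forwarded to the custom dataset class constructor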
  • Loading branch information
ShibaniSanturkar authored Aug 4, 2020
2 parents da6b6d3 + 7f87fcf commit 8058643
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
24 changes: 17 additions & 7 deletions robustness/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,30 @@ def __init__(self, ds_name, data_path, **kwargs):
transforms to apply to the validation images from the
dataset
"""
required_args = ['num_classes', 'mean', 'std', 'custom_class',
'label_mapping', 'transform_train', 'transform_test']
assert set(kwargs.keys()) == set(required_args), "Missing required args, only saw %s" % kwargs.keys()
required_args = ['num_classes', 'mean', 'std',
'transform_train', 'transform_test']
optional_args = ['custom_class', 'label_mapping', 'custom_class_args']

missing_args = set(required_args) - set(kwargs.keys())
if len(missing_args) > 0:
raise ValueError("Missing required args %s" % missing_args)

extra_args = set(kwargs.keys()) - set(required_args + optional_args)
if len(extra_args) > 0:
raise ValueError("Got unrecognized args %s" % extra_args)
final_kwargs = {k: kwargs.get(k, None) for k in required_args + optional_args}

self.ds_name = ds_name
self.data_path = data_path
self.__dict__.update(kwargs)
self.__dict__.update(final_kwargs)

def override_args(self, default_args, new_args):
'''
Convenience method for overriding arguments. (Internal)
'''
kwargs = {k: v for (k, v) in new_args.items() if v is not None}
extra_args = set(kwargs.keys()) - set(default_args.keys())
if len(extra_args) > 0: raise ValueError(f"Invalid arguments: {extra_args}")
for k in kwargs:
if not (k in default_args): continue
req_type = type(default_args[k])
no_nones = (default_args[k] is not None) and (kwargs[k] is not None)
if no_nones and (not isinstance(kwargs[k], req_type)):
Expand Down Expand Up @@ -164,7 +173,8 @@ def make_loaders(self, workers, batch_size, data_aug=True, subset=None,
only_val=only_val,
seed=subset_seed,
shuffle_train=shuffle_train,
shuffle_val=shuffle_val)
shuffle_val=shuffle_val,
custom_class_args=self.custom_class_args)

class ImageNet(DataSet):
'''
Expand Down
12 changes: 7 additions & 5 deletions robustness/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
def make_loaders(workers, batch_size, transforms, data_path, data_aug=True,
custom_class=None, dataset="", label_mapping=None, subset=None,
subset_type='rand', subset_start=0, val_batch_size=None,
only_val=False, shuffle_train=True, shuffle_val=True, seed=1):
only_val=False, shuffle_train=True, shuffle_val=True, seed=1,
custom_class_args=None):
'''
**INTERNAL FUNCTION**
Expand Down Expand Up @@ -58,11 +59,12 @@ def make_loaders(workers, batch_size, transforms, data_path, data_aug=True,
test_set = folder.ImageFolder(root=test_path, transform=transform_test,
label_mapping=label_mapping)
else:
if custom_class_args is None: custom_class_args = {}
if not only_val:
train_set = custom_class(root=data_path, train=True,
download=True, transform=transform_train)
test_set = custom_class(root=data_path, train=False,
download=True, transform=transform_test)
train_set = custom_class(root=data_path, train=True, download=True,
transform=transform_train, **custom_class_args)
test_set = custom_class(root=data_path, train=False, download=True,
transform=transform_test, **custom_class_args)

if not only_val:
attrs = ["samples", "train_data", "data"]
Expand Down

0 comments on commit 8058643

Please sign in to comment.