Skip to content

Commit

Permalink
Move config parameters to a different table so that we can define a dataloader easier (#91)
Browse files Browse the repository at this point in the history
  • Loading branch information
drewoldag authored Oct 15, 2024
1 parent de91389 commit dad7062
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/fibad/data_sets/hsc_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def __init__(self, config):
# Because it goes from unbounded NN output space -> [-1,1] with tanh in its decode step.
transform = Lambda(lambd=np.tanh)

crop_to = config["data_loader"]["crop_to"]
filters = config["data_loader"]["filters"]
crop_to = config["data_set"]["crop_to"]
filters = config["data_set"]["filters"]

self._init_from_path(
config["general"]["data_dir"],
Expand Down
2 changes: 1 addition & 1 deletion src/fibad/fibad_default_config.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ resume = false
# e.g. "user_package.submodule.ExternalDataLoader" or "HSCDataLoader"
name = "CifarDataSet"

[data_loader]
# Pixel dimensions used to crop all images prior to loading. Will prune any images that are too small.
#
# If not provided by user, the default of 'false' scans the directory for the smallest dimensioned files, and
Expand All @@ -87,6 +86,7 @@ crop_to = false
#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]
filters = false

[data_loader]
# Default PyTorch DataLoader parameters
batch_size = 32
shuffle = true
Expand Down
12 changes: 1 addition & 11 deletions src/fibad/pytorch_ignite.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,7 @@ def dist_data_loader(data_set: Dataset, config: ConfigDict):
Dataloader (or an ignite-wrapped equivalent)
This is the distributed dataloader, formed by calling ignite.distributed.auto_dataloader
"""
# ~ idist.auto_dataloader will accept a **kwargs parameter, and pass values
# ~ through to the underlying pytorch DataLoader.
# ~ Currently, our config includes unexpected keys like `name`, that cause
# ~ an exception. It would be nice to reduce this to:
# ~ `data_loader = idist.auto_dataloader(data_set, **config)`
return idist.auto_dataloader(
data_set,
batch_size=config["data_loader"]["batch_size"],
shuffle=config["data_loader"]["shuffle"],
num_workers=config["data_loader"]["num_workers"],
)
return idist.auto_dataloader(data_set, **config["data_loader"])


def create_engine(funcname: str, device: torch.device, model: torch.nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion tests/fibad/test_hsc_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def mkconfig(crop_to=False, filters=False):
"""
return {
"general": {"data_dir": "thispathdoesnotexist"},
"data_loader": {
"data_set": {
"crop_to": crop_to,
"filters": filters,
},
Expand Down

0 comments on commit dad7062

Please sign in to comment.