Add a configurable filter list to HSC data set #53

Merged 1 commit on Aug 29, 2024
47 changes: 37 additions & 10 deletions src/fibad/data_loaders/hsc_data_loader.py
@@ -49,6 +49,7 @@ def data_set(self):
self.config.get("path", "./data"),
transform=transform,
cutout_shape=self.config.get("crop_to", None),
filters=self.config.get("filters", None),
)

def data_loader(self, data_set):
@@ -65,7 +66,12 @@ def shape(self):

class HSCDataSet(Dataset):
def __init__(
self, path: Union[Path, str], *, transform=None, cutout_shape: Optional[tuple[int, int]] = None
self,
path: Union[Path, str],
*,
transform=None,
cutout_shape: Optional[tuple[int, int]] = None,
filters: Optional[list[str]] = None,
):
"""Initialize an HSC data set from a path. This involves several filesystem scan operations and will
ultimately open and read the header info of every fits file in the given directory
@@ -78,19 +84,24 @@ def __init__(
transform : torchvision.transforms.v2.Transform, optional
Transformation to apply to every image in the dataset, by default None
cutout_shape: tuple[int,int], optional
Forces all cutouts to be a particular pixel size. RuntimeError is raised if this size is larger
than the pixel dimension of any cutout in the dataset.
Forces all cutouts to be a particular pixel size. If this size is larger than the pixel dimension
of particular cutouts on the filesystem, those objects are dropped from the data set.
filters: list[str], optional
Forces all cutout tensors provided to be from the list of HSC filters provided. If provided, any
cutouts which do not have fits files corresponding to every filter in the list will be dropped
from the data set. Defaults to None. If not provided, the filters available on the filesystem for
the first object in the directory will be used.
"""
self.path = path
self.transform = transform

self.files = self._scan_file_names()
self.files = self._scan_file_names(filters)
self.dims = self._scan_file_dimensions()

# We choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by
# _prune_objects
filters_ref = list(list(self.files.values())[0])
# If no filters provided, we choose the first file in the dict as the prototypical set of filters
# Any objects lacking this full set of filters will be pruned by _prune_objects
filters_ref = list(list(self.files.values())[0]) if filters is None else filters

self.num_filters = len(filters_ref)

self.cutout_shape = cutout_shape
@@ -109,20 +120,36 @@ def __init__(

logger.info(f"HSC Data set loader has {len(self)} objects")

def _scan_file_names(self) -> dict[str, dict[str, str]]:
def _scan_file_names(self, filters: Optional[list[str]] = None) -> dict[str, dict[str, str]]:
"""Class initialization helper

Parameters
----------
filters : list[str], optional
If passed, only these filters will be scanned for from the data files. Defaults to None, which
corresponds to the standard set of filters ["HSC-G","HSC-R","HSC-I","HSC-Z","HSC-Y"].

Returns
-------
dict[str,dict[str,str]]
Nested dictionary where the first level maps object_id -> dict, and the second level maps
filter_name -> file name. Corresponds to self.files
"""

object_id_regex = r"[0-9]{17}"
filter_regex = r"HSC-[GRIZY]" if filters is None else "|".join(filters)
full_regex = f"({object_id_regex})_.*_({filter_regex}).fits"
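# For illustration (this comment is not in the diff): with
# filters=["HSC-G", "HSC-R"] the pattern above becomes
# "([0-9]{17})_.*_(HSC-G|HSC-R).fits", while the default (filters=None)
# matches any of the five standard bands via "HSC-[GRIZY]".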

files = {}
# Go scan the path for object ID's so we have a list.
for filepath in Path(self.path).glob("[0-9]*.fits"):
filename = filepath.name
m = re.match(r"([0-9]{17})_.*\_(HSC-[GRIZY]).fits", filename)
m = re.match(full_regex, filename)

# Skip files that don't match the pattern.
if m is None:
Collaborator: If it doesn't make the process super slow, can we log the name of the file being skipped?

mtauraso (Collaborator, Author) commented on Aug 29, 2024: I'm more worried about log spam. Adding a debug or info level log here shouldn't slow things down unless the log is being emitted to a console.

I am thinking though that the better solution is to output that manifest fits table, which will have all the skipped files explicitly and not create a potential foot-gun for people changing the logging level to info/debug.

Collaborator: I would advocate for @mtauraso's approach here. Perhaps a middle ground would be logging some summary metrics at the end along with a message saying to look in the manifest fits table for skipped files?

Collaborator: Yes, putting the info in the manifest table sounds good! I also like @drewoldag's idea of some summary metrics at the end if it's easy to implement!
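
To make the summary-metrics idea concrete, here is a minimal sketch of what the skip-counting could look like; the counter name, message text, and manifest reference are assumptions for illustration, not code from this PR:

    skipped = 0
    for filepath in Path(self.path).glob("[0-9]*.fits"):
        m = re.match(full_regex, filepath.name)
        if m is None:
            # Count quietly instead of logging every file, to avoid log spam
            skipped += 1
            continue
        ...  # existing handling of matched files
    if skipped > 0:
        logger.warning(f"Skipped {skipped} non-matching files; see the manifest table for details")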

continue

object_id = m[1]
filter = m[2]

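For context, a minimal usage sketch of the new parameter, assuming the constructor and shape() method shown in this diff (the data path and cutout size here are illustrative, not from the PR):

    from fibad.data_loaders.hsc_data_loader import HSCDataSet

    # Restrict the data set to two bands; any object lacking a fits file
    # for either filter is pruned during initialization.
    data_set = HSCDataSet(
        "./data",
        cutout_shape=(100, 100),
        filters=["HSC-G", "HSC-R"],
    )

    print(len(data_set))     # number of objects that survived pruning
    print(data_set.shape())  # (2, 100, 100): num_filters x height x width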
7 changes: 7 additions & 0 deletions src/fibad/fibad_default_config.toml
@@ -66,6 +66,13 @@ path = "./data"
#
#crop_to = [100,100]

# Limit data loader to only particular filters when there are more in the data set.
#
# When not provided, the number of filters will be automatically gleaned from the data set.
# Defaults to not provided.
#
#filters = ["HSC-G", "HSC-R", "HSC-I", "HSC-Z", "HSC-Y"]

# Default PyTorch DataLoader parameters
batch_size = 500
shuffle = true
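As a usage note, turning the option on is just a matter of uncommenting and editing that last line, e.g. to restrict training to three bands (a sketch against this default config; the particular band subset is arbitrary):

    filters = ["HSC-G", "HSC-R", "HSC-I"]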
45 changes: 45 additions & 0 deletions tests/fibad/test_hsc_dataset.py
@@ -285,3 +285,48 @@ def test_prune_size(caplog):
# We should warn that we are dropping objects and the reason
assert "Dropping object" in caplog.text
assert "too small" in caplog.text


def test_partial_filter(caplog):
"""Test to ensure when we only load some of the filters, only those filters end up in the dataset"""
caplog.set_level(logging.WARNING)
test_files = generate_files(num_objects=10, num_filters=5, shape=(262, 263))
with FakeFitsFS(test_files):
a = HSCDataSet("thispathdoesnotexist", filters=["HSC-G", "HSC-R"])

# 10 objects should load
assert len(a) == 10

# The number of filters and image dimensions should be correct
assert a.shape() == (2, 262, 263)

# No warnings should be printed
assert caplog.text == ""


def test_partial_filter_prune_warn_1_percent(caplog):
"""Test to ensure when a the user supplies a filter list and >1% of loaded objects are
missing a filter, that is a warning and that the resulting dataset drops the objects that
are missing filters.
"""
caplog.set_level(logging.WARNING)

# Generate 98 objects that each have all three filters
test_files = generate_files(num_objects=98, num_filters=3, shape=(100, 100))
# Object 101 is missing the HSC-G and HSC-I filters; we only provide the HSC-R file
test_files["00000000000000101_missing_g_HSC-R.fits"] = (100, 100)

with FakeFitsFS(test_files):
a = HSCDataSet("thispathdoesnotexist", filters=["HSC-R", "HSC-I"])

# We should have the correct number of objects
assert len(a) == 98

# Object 101 should not be loaded
assert "00000000000000101" not in a

# We should error log because greater than 1% of the objects were pruned
assert "Greater than 1% of objects in the data directory were pruned." in caplog.text

# We should warn that we dropped an object explicitly
assert "Dropping object" in caplog.text