diff --git a/fil_finder/filament.py b/fil_finder/filament.py index 2dd358e..32d961a 100644 --- a/fil_finder/filament.py +++ b/fil_finder/filament.py @@ -1466,7 +1466,10 @@ def branch_table(self, include_rht=False): names=branch_data.keys()) return tab - def save_fits(self, savename, image, pad_size=20 * u.pix, header=None, + def save_fits(self, savename, image, + image_list=None, + pad_size=20 * u.pix, + header=None, model_kwargs={}, **kwargs): ''' @@ -1556,6 +1559,20 @@ def save_fits(self, savename, image, pad_size=20 * u.pix, header=None, hdulist = fits.HDUList([hdu, skel_hdu, skel_lp_hdu, model_hdu, tab_hdu]) + # If image_list is provided, save cutouts from the image list + if image_list is not None: + for key in image_list: + img = image_list[key] + img = pad_image(img, self.pixel_extents, pad_size) + if img.shape != skels.shape: + img = self.image_slicer(img, skels.shape, + pad_size=pad_size) + + img_hdu = fits.ImageHDU(img, header) + img_hdu.name = key.upper() + + hdulist.append(img_hdu) + hdulist.writeto(savename, **kwargs) def to_pickle(self, savename): diff --git a/fil_finder/filfinder2D.py b/fil_finder/filfinder2D.py index 56756a8..f343b6c 100644 --- a/fil_finder/filfinder2D.py +++ b/fil_finder/filfinder2D.py @@ -1432,7 +1432,10 @@ def save_fits(self, save_name=None, out_hdu.writeto("{0}_image_output.fits".format(save_name), **kwargs) - def save_stamp_fits(self, save_name=None, pad_size=20 * u.pix, + def save_stamp_fits(self, + image_list=None, + save_name=None, + pad_size=20 * u.pix, model_kwargs={}, **kwargs): ''' @@ -1444,6 +1447,10 @@ def save_stamp_fits(self, save_name=None, pad_size=20 * u.pix, Parameters ---------- + image_list : dict, optional + Dictionary of arrays to save matching the pixel extents of each filament. + The shape of each array *must* be the same shape as the original image + given to `~FilFinder2D`. save_name : str, optional The prefix for the saved file. If None, the save name specified when `~FilFinder2D` was first called. @@ -1459,10 +1466,21 @@ def save_stamp_fits(self, save_name=None, pad_size=20 * u.pix, else: save_name = os.path.splitext(save_name)[0] + if image_list is not None: + for ii, key in enumerate(image_list): + this_image = image_list[key] + if this_image.shape != self.image.shape: + raise ValueError("All images in image_list must be same shape as fil.image. " + f"For index {ii}, found shape {this_image.shape} not {self.image.shape}") + + for n, fil in enumerate(self.filaments): - savename = "{0}_stamp_{1}.fits".format(save_name, n) + savename = f"{save_name}_stamp_{n}.fits" - fil.save_fits(savename, self.image, pad_size=pad_size, + fil.save_fits(savename, + self.image, + image_list=image_list, + pad_size=pad_size, model_kwargs=model_kwargs, **kwargs)