Skip to content

Commit

Permalink
Add single user input for defining a parallel processing pool
Browse files Browse the repository at this point in the history
  • Loading branch information
e-koch committed Sep 30, 2024
1 parent 5b210ac commit a0f9d69
Showing 1 changed file with 16 additions and 14 deletions.
30 changes: 16 additions & 14 deletions fil_finder/filfinder2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ class FilFinder2D(BaseInfoMixin):
save_name : str, optional
Sets the prefix name that is used for output files. Can be overridden
in ``save_fits`` and ``save_table``. Default is "FilFinder_output".
pool : None, concurrent.futures.ProcessPoolExecutor or mpi4py.futures.MPIPool , optional
Allows for parallel processing. The default of None will use a
concurrent.futures.ProcessPoolExecutor with `nthreads` processes.
nthreads : int, optional
The number of threads to use in parallel processing. Only used if
``pool`` is None to initialize concurrent.futures.ProcessPoolExecutor.
Examples
--------
Expand All @@ -91,7 +97,8 @@ class FilFinder2D(BaseInfoMixin):

def __init__(self, image, header=None, beamwidth=None, ang_scale=None,
distance=None, mask=None, save_name="FilFinder_output",
capture_pre_recombine_masks=False):
capture_pre_recombine_masks=False,
pool=None, nthreads=1):

# Accepts a numpy array or fits.PrimaryHDU
output = input_data(image, header)
Expand Down Expand Up @@ -158,6 +165,10 @@ def __init__(self, image, header=None, beamwidth=None, ang_scale=None,
self._pre_recombine_mask_objs = None
self._pre_recombine_mask_corners = None

if pool is None:
pool = concurrent.futures.ProcessPoolExecutor(max_workers=nthreads)
self.pool = pool

def preprocess_image(self, skip_flatten=False, flatten_percent=None):
'''
Preprocess and flatten the image before running the masking routine.
Expand Down Expand Up @@ -565,7 +576,6 @@ def medskel(self, verbose=False, save_png=False, rng=None):
p.clf()

def analyze_skeletons(self,
nthreads=1,
prune_criteria='all',
relintens_thresh=0.2,
nbeam_lengths=5,
Expand All @@ -585,8 +595,6 @@ def analyze_skeletons(self,
Parameters
----------
nthreads : int, optional
Number of threads to use to parallelize the skeleton analysis.
prune_criteria : {'all', 'intensity', 'length'}, optional
Choose the property to base pruning on. 'all' requires that the
branch fails to satisfy the length and relative intensity checks.
Expand Down Expand Up @@ -682,7 +690,7 @@ def analyze_skeletons(self,
converter=self.converter))

# Now loop over the skeleton analysis for each filament object
with concurrent.futures.ProcessPoolExecutor(nthreads) as executor:
with self.pool as executor:
futures = [executor.submit(fil.skeleton_analysis, self.image,
verbose=verbose,
save_png=save_png,
Expand Down Expand Up @@ -789,7 +797,6 @@ def end_pts(self):
return [fil.end_pts for fil in self.filaments]

def exec_rht(self,
nthreads=1,
radius=10 * u.pix,
ntheta=180, background_percentile=25,
branches=False, min_branch_length=3 * u.pix,
Expand All @@ -815,8 +822,6 @@ def exec_rht(self,
Parameters
----------
nthreads : int, optional
The number of threads to use.
radius : int
Sets the patch size that the RHT uses.
ntheta : int, optional
Expand Down Expand Up @@ -859,7 +864,7 @@ def exec_rht(self,


if branches:
with concurrent.futures.ProcessPoolExecutor(nthreads) as executor:
with self.pool as executor:
futures = [executor.submit(fil.rht_branch_analysis,
radius=radius,
ntheta=ntheta,
Expand All @@ -871,7 +876,7 @@ def exec_rht(self,


else:
with concurrent.futures.ProcessPoolExecutor(nthreads) as executor:
with self.pool as executor:
futures = [executor.submit(fil.rht_analysis,
radius=radius,
ntheta=ntheta,
Expand Down Expand Up @@ -943,7 +948,6 @@ def pre_recombine_mask_corners(self):
return self._pre_recombine_mask_corners

def find_widths(self,
nthreads=1,
max_dist=10 * u.pix,
pad_to_distance=0 * u.pix,
fit_model='gaussian_bkg',
Expand Down Expand Up @@ -973,8 +977,6 @@ def find_widths(self,
Parameters
----------
nthreads : int, optional
Number of threads to use.
max_dist : `~astropy.units.Quantity`, optional
Largest radius around the skeleton to create the profile from. This
can be given in physical, angular, or physical units.
Expand Down Expand Up @@ -1021,7 +1023,7 @@ def find_widths(self,
if save_name is None:
save_name = self.save_name

with concurrent.futures.ProcessPoolExecutor(nthreads) as executor:
with self.pool as executor:
futures = [executor.submit(fil.width_analysis, self.image,
all_skeleton_array=self.skeleton,
max_dist=max_dist,
Expand Down

0 comments on commit a0f9d69

Please sign in to comment.