From 7e396b579de4c38dd6b26aca487c1aab729b356b Mon Sep 17 00:00:00 2001 From: Eric Koch Date: Thu, 15 Aug 2024 12:07:00 -0400 Subject: [PATCH] Add parallelization for FilFinder2D.analyze_skeletons --- fil_finder/filament.py | 9 +++++- fil_finder/filfinder2D.py | 60 ++++++++++++++++++++++++++++----------- 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/fil_finder/filament.py b/fil_finder/filament.py index ec05578..065d664 100644 --- a/fil_finder/filament.py +++ b/fil_finder/filament.py @@ -185,7 +185,8 @@ def skeleton(self, pad_size=0, corner_pix=None, out_type='all'): def skeleton_analysis(self, image, verbose=False, save_png=False, save_name=None, prune_criteria='all', relintens_thresh=0.2, max_prune_iter=10, - branch_thresh=0 * u.pix): + branch_thresh=0 * u.pix, + return_self=False): ''' Run the skeleton analysis. @@ -215,6 +216,9 @@ def skeleton_analysis(self, image, verbose=False, save_png=False, Maximum number of pruning iterations to apply. branch_thresh : `~astropy.units.Quantity`, optional Minimum length for a branch to be eligible to be pruned. + return_self : bool, optional + Return the Filament2D object after skeleton analysis. This is needed + for parallel processing in FilFinder2D. ''' # NOTE: @@ -391,6 +395,9 @@ def skeleton_analysis(self, image, verbose=False, save_png=False, 'number': branch_properties['number'][0], 'pixels': branch_properties['pixels'][0]} + if return_self: + return self + @property def branch_properties(self): ''' diff --git a/fil_finder/filfinder2D.py b/fil_finder/filfinder2D.py index beee31e..77e5e28 100644 --- a/fil_finder/filfinder2D.py +++ b/fil_finder/filfinder2D.py @@ -14,6 +14,7 @@ import os import time import warnings +import concurrent.futures from .pixel_ident import recombine_skeletons, isolateregions from .utilities import eight_con, round_to_odd, threshold_local, in_ipynb @@ -563,11 +564,18 @@ def medskel(self, verbose=False, save_png=False, rng=None): if in_ipynb(): p.clf() - def analyze_skeletons(self, prune_criteria='all', relintens_thresh=0.2, - nbeam_lengths=5, branch_nbeam_lengths=3, - skel_thresh=None, branch_thresh=None, + def analyze_skeletons(self, + nthreads=1, + prune_criteria='all', + relintens_thresh=0.2, + nbeam_lengths=5, + branch_nbeam_lengths=3, + skel_thresh=None, + branch_thresh=None, max_prune_iter=10, - verbose=False, save_png=False, save_name=None): + verbose=False, + save_png=False, + save_name=None): ''' Prune skeleton structure and calculate the branch and longest-path @@ -576,6 +584,8 @@ def analyze_skeletons(self, prune_criteria='all', relintens_thresh=0.2, Parameters ---------- + nthreads : int, optional + Number of threads to use to parallelize the skeleton analysis. prune_criteria : {'all', 'intensity', 'length'}, optional Choose the property to base pruning on. 'all' requires that the branch fails to satisfy the length and relative intensity checks. @@ -650,25 +660,43 @@ def analyze_skeletons(self, prune_criteria='all', relintens_thresh=0.2, # Relabel after deleting short skeletons. labels, num = nd.label(self.skeleton, eight_con()) + self.filaments = [Filament2D(np.where(labels == lab), converter=self.converter) for lab in range(1, num + 1)] + with concurrent.futures.ProcessPoolExecutor(nthreads) as executor: + futures = [executor.submit(fil.skeleton_analysis, self.image, + verbose=verbose, + save_png=save_png, + save_name=save_name, + prune_criteria=prune_criteria, + relintens_thresh=relintens_thresh, + branch_thresh=self.branch_thresh, + max_prune_iter=max_prune_iter, + return_self=True) + for fil in self.filaments] + self.filaments = [future.result() for future in futures] + + print(self.filaments[0].length(),) + print(self.filaments[0].branch_properties['length'],) + print(self.filaments[0].pixel_coords) + self.number_of_filaments = num # Now loop over the skeleton analysis for each filament object - for n, fil in enumerate(self.filaments): - savename = "{0}_{1}".format(save_name, n) - if verbose: - print("Filament: %s / %s" % (n + 1, self.number_of_filaments)) - - fil.skeleton_analysis(self.image, verbose=verbose, - save_png=save_png, - save_name=savename, - prune_criteria=prune_criteria, - relintens_thresh=relintens_thresh, - branch_thresh=self.branch_thresh, - max_prune_iter=max_prune_iter) + # for n, fil in enumerate(self.filaments): + # savename = "{0}_{1}".format(save_name, n) + # if verbose: + # print("Filament: %s / %s" % (n + 1, self.number_of_filaments)) + + # fil.skeleton_analysis(self.image, verbose=verbose, + # save_png=save_png, + # save_name=savename, + # prune_criteria=prune_criteria, + # relintens_thresh=relintens_thresh, + # branch_thresh=self.branch_thresh, + # max_prune_iter=max_prune_iter) self.array_offsets = [fil.pixel_extents for fil in self.filaments]