From 5b210acf0924dd7995294fa5b3f755436525c046 Mon Sep 17 00:00:00 2001 From: Eric Koch Date: Mon, 30 Sep 2024 10:23:45 -0400 Subject: [PATCH] Add additional checks and optional debug print mode to avoid passing empty skeletons to FilFinder2D --- fil_finder/filfinder2D.py | 40 +++++++++++++++++++++++++++------------ 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/fil_finder/filfinder2D.py b/fil_finder/filfinder2D.py index 11bd346..3e3146a 100644 --- a/fil_finder/filfinder2D.py +++ b/fil_finder/filfinder2D.py @@ -575,7 +575,8 @@ def analyze_skeletons(self, max_prune_iter=10, verbose=False, save_png=False, - save_name=None): + save_name=None, + debug=False,): ''' Prune skeleton structure and calculate the branch and longest-path @@ -614,6 +615,8 @@ def analyze_skeletons(self, Saves the plot made in verbose mode. Disabled by default. save_name : str, optional Prefix for the saved plots. + debug : bool, optional + Enables debug mode with extra printing ''' if relintens_thresh > 1.0 or relintens_thresh <= 0.0: @@ -636,7 +639,9 @@ def analyze_skeletons(self, else: skel_thresh = self.converter.to_pixel(skel_thresh) - self.skel_thresh = np.ceil(skel_thresh) + # Ensure the minimum is always >1 pixel. + + self.skel_thresh = min(np.ceil(skel_thresh), 1 * u.pix) # Set the minimum branch length to be the beam size. if branch_thresh is None: @@ -649,21 +654,32 @@ def analyze_skeletons(self, # Label individual filaments and define the set of filament objects labels, num = nd.label(self.skeleton, eight_con()) + if debug: + print(f"Found {num} filaments before removing short skeletons") + # Find the objects that don't satisfy skel_thresh - if self.skel_thresh > 0.: - obj_sums = nd.sum(self.skeleton, labels, range(1, num + 1)) - remove_fils = np.where(obj_sums <= self.skel_thresh.value)[0] + obj_sums = nd.sum(self.skeleton, labels, range(1, num + 1)) + remove_fils = np.where(obj_sums <= self.skel_thresh.value)[0] + + for lab in remove_fils: + if debug: + print(f"Removing {lab} with {obj_sums[lab]} pixels") + self.skeleton[np.where(labels == lab + 1)] = 0 + + # Relabel after deleting short skeletons. + labels, num = nd.label(self.skeleton, eight_con()) - for lab in remove_fils: - self.skeleton[np.where(labels == lab + 1)] = 0 + if debug: + print(f"Found {num} filaments after removing short skeletons") - # Relabel after deleting short skeletons. - labels, num = nd.label(self.skeleton, eight_con()) + self.filaments = [] + for lab in range(1, num + 1): + if debug: + print(f"Filament {lab} has {np.sum(labels == lab)} pixels") - self.filaments = [Filament2D(np.where(labels == lab), - converter=self.converter) for lab in - range(1, num + 1)] + self.filaments.append(Filament2D(np.where(labels == lab), + converter=self.converter)) # Now loop over the skeleton analysis for each filament object with concurrent.futures.ProcessPoolExecutor(nthreads) as executor: