Skip to content

Commit

Permalink
Add additional checks and optional debug print mode to avoid passing …
Browse files Browse the repository at this point in the history
…empty skeletons to FilFinder2D
  • Loading branch information
e-koch committed Sep 30, 2024
1 parent 07d38f3 commit 5b210ac
Showing 1 changed file with 28 additions and 12 deletions.
40 changes: 28 additions & 12 deletions fil_finder/filfinder2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,8 @@ def analyze_skeletons(self,
max_prune_iter=10,
verbose=False,
save_png=False,
save_name=None):
save_name=None,
debug=False,):
'''
Prune skeleton structure and calculate the branch and longest-path
Expand Down Expand Up @@ -614,6 +615,8 @@ def analyze_skeletons(self,
Saves the plot made in verbose mode. Disabled by default.
save_name : str, optional
Prefix for the saved plots.
debug : bool, optional
Enables debug mode with extra printing
'''

if relintens_thresh > 1.0 or relintens_thresh <= 0.0:
Expand All @@ -636,7 +639,9 @@ def analyze_skeletons(self,
else:
skel_thresh = self.converter.to_pixel(skel_thresh)

self.skel_thresh = np.ceil(skel_thresh)
# Ensure the minimum is always >1 pixel.

self.skel_thresh = min(np.ceil(skel_thresh), 1 * u.pix)

# Set the minimum branch length to be the beam size.
if branch_thresh is None:
Expand All @@ -649,21 +654,32 @@ def analyze_skeletons(self,
# Label individual filaments and define the set of filament objects
labels, num = nd.label(self.skeleton, eight_con())

if debug:
print(f"Found {num} filaments before removing short skeletons")

# Find the objects that don't satisfy skel_thresh
if self.skel_thresh > 0.:
obj_sums = nd.sum(self.skeleton, labels, range(1, num + 1))
remove_fils = np.where(obj_sums <= self.skel_thresh.value)[0]
obj_sums = nd.sum(self.skeleton, labels, range(1, num + 1))
remove_fils = np.where(obj_sums <= self.skel_thresh.value)[0]

for lab in remove_fils:
if debug:
print(f"Removing {lab} with {obj_sums[lab]} pixels")
self.skeleton[np.where(labels == lab + 1)] = 0

# Relabel after deleting short skeletons.
labels, num = nd.label(self.skeleton, eight_con())

for lab in remove_fils:
self.skeleton[np.where(labels == lab + 1)] = 0
if debug:
print(f"Found {num} filaments after removing short skeletons")

# Relabel after deleting short skeletons.
labels, num = nd.label(self.skeleton, eight_con())
self.filaments = []

for lab in range(1, num + 1):
if debug:
print(f"Filament {lab} has {np.sum(labels == lab)} pixels")

self.filaments = [Filament2D(np.where(labels == lab),
converter=self.converter) for lab in
range(1, num + 1)]
self.filaments.append(Filament2D(np.where(labels == lab),
converter=self.converter))

# Now loop over the skeleton analysis for each filament object
with concurrent.futures.ProcessPoolExecutor(nthreads) as executor:
Expand Down

0 comments on commit 5b210ac

Please sign in to comment.