Skip to content

Commit

Permalink
Merge branch 'LauraWiggins/connect_nodes' of https://github.com/AFM-S…
Browse files Browse the repository at this point in the history
…PM/TopoStats into maxgamill-sheffield/cats
  • Loading branch information
MaxGamill-Sheffield committed Sep 26, 2023
2 parents 39e821f + e75ebab commit eec0a7b
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 3 deletions.
3 changes: 2 additions & 1 deletion topostats/grains.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,8 @@ def find_grains(self):
absolute=self.threshold_absolute,
)
for direction in self.direction:
grain_finding_workflow = "not_unet"
grain_finding_workflow = "unet"
grain_finding_workflow = "!unet"
if grain_finding_workflow == "unet":
LOGGER.info(f"[{self.filename}] : Finding {direction} grains via UNet")
self.directions[direction] = {}
Expand Down
84 changes: 82 additions & 2 deletions topostats/tracing/dnatracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from topoly import jones, homfly, params, reduce_structure, translate_code
import skimage.measure as skimage_measure
from tqdm import tqdm
import math as math

from topostats.logs.logs import LOGGER_NAME
from topostats.tracing.tracingfuncs import genTracingFuncs, reorderTrace
Expand Down Expand Up @@ -1208,6 +1209,7 @@ def get_node_stats(self) -> dict:
# np.savetxt(OUTPUT_DIR / "conv.txt", self.conv_skelly)
if len(self.conv_skelly[self.conv_skelly == 3]) != 0: # check if any nodes
self.connect_close_nodes(self.conv_skelly, node_width=7e-9)
self.connected_nodes = self.connect_extended_nodes(self.connected_nodes)
# np.savetxt(OUTPUT_DIR / "img.txt", self.image)
# np.savetxt(OUTPUT_DIR / "untidied.txt", self.connected_nodes)
# self.connected_nodes = self.tidy_branches(self.connected_nodes, self.image)
Expand Down Expand Up @@ -1285,8 +1287,6 @@ def tidy_branches(self, connect_node_mask: np.ndarray, image: np.ndarray):
] = 1
# new_skeleton[node_centre[0]-node_wid//2-10:node_centre[0]+node_wid//2+10, node_centre[1]-node_len//2-10:node_centre[1]+node_len//2+10] = self.grain[node_centre[0]-node_wid//2-10:node_centre[0]+node_wid//2+10, node_centre[1]-node_len//2-10:node_centre[1]+node_len//2+10]

# np.savetxt(OUTPUT_DIR / "splodge.txt", new_skeleton)

new_skeleton = getSkeleton(image, new_skeleton).get_skeleton(method="joe", params={"height_bias": 0.6})
new_skeleton = pruneSkeleton(image, new_skeleton).prune_skeleton(method="joe")
new_skeleton = getSkeleton(image, new_skeleton).get_skeleton(method="zhang") # cleanup around nibs
Expand Down Expand Up @@ -1334,6 +1334,84 @@ def highlight_node_centres(self, mask):

return small_node_mask

def connect_extended_nodes(self, connected_nodes):
just_nodes = connected_nodes.copy()
just_nodes[(connected_nodes == 1) | (connected_nodes == 2)] = 0 # remove branches & termini points
labelled = label(just_nodes)

just_branches = connected_nodes.copy()
just_branches[(connected_nodes == 3) | (connected_nodes == 2)] = 0 # remove node & termini points
just_branches[connected_nodes == 1] = labelled.max()+1
labelled_branches = label(just_branches)

def bounding_box(points):
x_coordinates, y_coordinates = zip(*points)

return [(min(x_coordinates), min(y_coordinates)), (max(x_coordinates), max(y_coordinates))]

def do_sets_touch(set_A, set_B):
# Iterate through coordinates in set_A and set_B
for point_A in set_A:
for point_B in set_B:
# Check if any coordinate in set_A is adjacent to any coordinate in set_B
if (
abs(point_A[0] - point_B[0]) <= 1
and abs(point_A[1] - point_B[1]) <= 1
):
return True # Sets touch
return False # Sets do not touch

emanating_branches_by_node = {} # Dictionary to store emanating branches for each label
nodes_with_odd_branches = [] # List to store nodes with three branches

for node in range(1, labelled.max()+1):
num_branches = 0
bounding = bounding_box(np.argwhere(labelled == node))
cropped_matrix = connected_nodes[bounding[0][0]-1:bounding[1][0] + 2, bounding[0][1]-1:bounding[1][1] + 2]
node_coords = np.argwhere(cropped_matrix == 3)
branch_coords = np.argwhere(cropped_matrix == 1)
for node_coord in node_coords:
for branch_coord in branch_coords:
distance = math.dist(node_coord, branch_coord)
if(distance <= math.sqrt(2)):
num_branches = num_branches+1
#num_branches = len(np.argwhere(cropped_matrix == 1))
print(f"node {node} has {num_branches} branches")

if(num_branches % 2 == 1):
nodes_with_odd_branches.append(node)
emanating_branches = [] # List to store emanating branches for the current label
for branch in range(1, labelled_branches.max() + 1):
touching = do_sets_touch(np.argwhere(labelled_branches == branch), np.argwhere(labelled == node))
if touching:
emanating_branches.append(branch)
emanating_branches_by_node[node] = emanating_branches # Store emanating branches for this label
print(node, emanating_branches_by_node[node])

# Iterate through the nodes and their emanating branches
for node1, branches1 in emanating_branches_by_node.items():
for node2, branches2 in emanating_branches_by_node.items():
if node1 != node2: # Avoid comparing a node with itself
# Find the common branches between the lists
common_branches = set(branches1) & set(branches2)
if common_branches:
min_length = float('inf') # Initialize with positive infinity
# Find the minimum length among all common branches
for shared_branch in common_branches:
length = len(np.argwhere(labelled_branches == shared_branch))
if length < min_length:
min_length = length
print(f"minimum length: {min_length}")
# Change the value to 3 only when len is minimal
for shared_branch in common_branches:
length = len(np.argwhere(labelled_branches == shared_branch))
if length == min_length:
print(shared_branch)
connected_nodes[labelled_branches == shared_branch] = 3

self.connected_nodes = connected_nodes
return self.connected_nodes

@staticmethod
def find_branch_starts(reduced_node_image: np.ndarray) -> np.ndarray:
"""Finds the corrdinates where the branches connect to the node region through binary dilation of the node.
Expand Down Expand Up @@ -1398,6 +1476,8 @@ def analyse_nodes(self, max_branch_length: float = 20e-9):

reduced_node_area = self._only_centre_branches(self.connected_nodes.copy(), (x, y))
branch_mask = reduced_node_area.copy()
#np.savetxt(OUTPUT_DIR / f"branch_mask{node_no}.txt", branch_mask)

branch_mask[branch_mask == 3] = 0
branch_mask[branch_mask == 2] = 1
node_coords = np.argwhere(reduced_node_area == 3)
Expand Down

0 comments on commit eec0a7b

Please sign in to comment.