Skip to content

Commit

Permalink
weights skel graph, + 0 cross contingency, avg idx error, overlap nod…
Browse files Browse the repository at this point in the history
…e coords
  • Loading branch information
MaxGamill-Sheffield committed Sep 28, 2023
1 parent 656f872 commit 850f668
Showing 1 changed file with 43 additions and 105 deletions.
148 changes: 43 additions & 105 deletions topostats/tracing/dnatracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1145,7 +1145,7 @@ def __init__(
self.px_2_nm = px_2_nm
self.n_grain = n_grain

self.skel_graph = self.skeleton_image_to_graph(self.skeleton)
self.skel_graph = None
sigma = (-3.5 / 3) * self.px_2_nm * 1e9 + 15.5 / 3
self.hess = self.detect_ridges(self.image * 1e9, 4)
# np.savetxt(OUTPUT_DIR / "hess.txt", self.hess)
Expand Down Expand Up @@ -1209,9 +1209,11 @@ def get_node_stats(self) -> dict:
# np.savetxt(OUTPUT_DIR / "conv.txt", self.conv_skelly)
if len(self.conv_skelly[self.conv_skelly == 3]) != 0: # check if any nodes
self.connect_close_nodes(self.conv_skelly, node_width=7e-9)
# TODO: maybe instead of connecting odds via sole branches - try only via shortest path?
self.connected_nodes = self.connect_extended_nodes(self.connected_nodes)
# np.savetxt(OUTPUT_DIR / "img.txt", self.image)
# np.savetxt(OUTPUT_DIR / "untidied.txt", self.connected_nodes)
#np.savetxt(OUTPUT_DIR / "untidied.txt", self.connected_nodes)
plt.imsave(OUTPUT_DIR / "connected_nodes.png", self.connected_nodes)
# self.connected_nodes = self.tidy_branches(self.connected_nodes, self.image)
self.node_centre_mask = self.highlight_node_centres(self.connected_nodes)
# np.savetxt(OUTPUT_DIR / "tidied.txt", self.connected_nodes)
Expand Down Expand Up @@ -1241,7 +1243,12 @@ def skeleton_image_to_graph(skel):
if skel[curNeighPos[0], curNeighPos[1]] > 0:
idx_coord = skeImPos[0, idx], skeImPos[1, idx]
curNeigh_coord = curNeighPos[0], curNeighPos[1]
g.add_edge(idx_coord, curNeigh_coord)
# assign lower weight to nodes if not a binary image
if skel[idx_coord] == 3 and skel[curNeigh_coord] == 3:
weight = 0
else:
weight = 1
g.add_edge(idx_coord, curNeigh_coord, weight=weight)
g.graph["physicalPos"] = skeImPos.T
return g

Expand Down Expand Up @@ -1449,7 +1456,7 @@ def analyse_nodes(self, max_branch_length: float = 20e-9):
# check whether average trace resides inside the grain mask
dilate = ndimage.binary_dilation(self.skeleton, iterations=2)
average_trace_advised = dilate[self.smoothed_grain == 1].sum() == dilate.sum()
LOGGER.info(f"Branch height traces will be averaged: {average_trace_advised}")
LOGGER.info(f"[{self.filename}] : Branch height traces will be averaged: {average_trace_advised}")

# iterate over the nodes to find areas
# node_dict = {}
Expand All @@ -1474,17 +1481,15 @@ def analyse_nodes(self, max_branch_length: float = 20e-9):
y + int(max_length_px * 1.2),
)

reduced_node_area = self._only_centre_branches(self.connected_nodes.copy(), (x, y))
# reduce the skeleton area
reduced_node_area = self._only_centre_branches(self.connected_nodes, (x, y))
self.skel_graph = self.skeleton_image_to_graph(reduced_node_area)
branch_mask = reduced_node_area.copy()
#np.savetxt(OUTPUT_DIR / f"branch_mask{node_no}.txt", branch_mask)

branch_mask[branch_mask == 3] = 0
branch_mask[branch_mask == 2] = 1
node_coords = np.argwhere(reduced_node_area == 3)

# plt.imsave(OUTPUT_DIR / "reduced_node_area.png", reduced_node_area)
# np.savetxt(OUTPUT_DIR / "reduced_node_area.txt", reduced_node_area)

error = False # to see if node too complex or region too small

branch_start_coords = self.find_branch_starts(reduced_node_area)
Expand Down Expand Up @@ -1540,30 +1545,33 @@ def analyse_nodes(self, max_branch_length: float = 20e-9):
)
# Get graphical shortest path between branch ends on the skeleton
crossing = nx.shortest_path(
self.skel_graph, tuple(branch_1_coords[-1]), tuple(branch_2_coords[0])
self.skel_graph,
source=tuple(branch_1_coords[-1]),
target=tuple(branch_2_coords[0]),
weight="weight"
)
crossing = np.asarray(crossing[1:-1]) # remove start and end points & turn into array
# Branch coords and crossing
branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords])
if crossing.shape == (0,):
branch_coords = np.vstack([branch_1_coords, branch_2_coords])
else:
branch_coords = np.vstack([branch_1_coords, crossing, branch_2_coords])
# make images of single branch joined and multiple branches joined
single_branch_img = np.zeros_like(self.skeleton)
single_branch_img[branch_coords[:, 0], branch_coords[:, 1]] = 1
single_branch_coords = self.order_branch(single_branch_img, [0, 0])
# calc image-wide coords
matched_branches[i]["ordered_coords"] = single_branch_coords
# get heights and trace distance of branch
if average_trace_advised:
try:
assert average_trace_advised
# np.savetxt(OUTPUT_DIR / "area.txt",image_area)
tmp = single_branch_img.copy()
tmp[x, y] = 2
print(x,y)
plt.imsave(OUTPUT_DIR / "sing.png", tmp)
plt.imsave(OUTPUT_DIR / "nodes.png", self.all_connected_nodes)
distances, heights, mask, _ = self.average_height_trace(
self.image, single_branch_img, single_branch_coords, [x, y]
) # hess_area
matched_branches[i]["avg_mask"] = mask
else:
except (AssertionError, IndexError) as e: # Assertion - avg trace not advised, Index - wiggy branches
average_trace_advised = False
distances = self.coord_dist_rad(
single_branch_coords, [x, y]
) # self.coord_dist(single_branch_coords)
Expand Down Expand Up @@ -1654,18 +1662,18 @@ def analyse_nodes(self, max_branch_length: float = 20e-9):
"branch_stats": matched_branches,
"node_stats": {
"node_coords": node_coords,
"node_area_image": self.image[
image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
], # self.hess
"node_area_grain": self.grain[
image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
],
"node_area_skeleton": reduced_node_area[
image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
],
"node_branch_mask": branch_img[
image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
],
"node_area_image": self.image, #[
#image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
#], # self.hess
"node_area_grain": self.grain, #[
#image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
#],
"node_area_skeleton": reduced_node_area, #[
#image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
#],
"node_branch_mask": branch_img, #[
#image_slices[0] : image_slices[1], image_slices[2] : image_slices[3]
#],
"node_avg_mask": avg_img,
},
}
Expand Down Expand Up @@ -1724,7 +1732,6 @@ def order_branch(self, binary_image: np.ndarray, anchor: list):
np.ndarray
An array of ordered cordinates.
"""
print("ORDER BRANCH:")
mask = binary_image.copy()

if len(np.argwhere(mask == 1)) < 3: # if < 3 coords just return them
Expand Down Expand Up @@ -2377,7 +2384,6 @@ def average_height_trace(
avg2.append([mid_dist, y])
avg1 = np.asarray(avg1)
avg2 = np.asarray(avg2)
print("AVGs: ", avg1.shape, avg2.shape)
# ensure arrays are same length to average
temp_x = branch_dist_norm[np.isin(branch_dist_norm, avg1[:, 0])]
common_dists = avg2[:, 0][np.isin(avg2[:, 0], temp_x)]
Expand Down Expand Up @@ -2499,11 +2505,11 @@ def compile_trace(self):
temp_distances = []
temp_fwhms = []
for _, branch_stats in stats["branch_stats"].items():
temp_nodes.append(["node_coords"])
temp_coords.append(branch_stats["ordered_coords"])
temp__heights.append(branch_stats["heights"])
temp_distances.append(branch_stats["distances"])
temp_fwhms.append(branch_stats["fwhm2"][0])
temp_nodes.append(stats["node_stats"]["node_coords"])
node_coords.append(temp_nodes)
crossing_coords.append(temp_coords)
crossing_heights.append(temp__heights)
Expand Down Expand Up @@ -2558,7 +2564,6 @@ def compile_trace(self):

# np.savetxt(OUTPUT_DIR / "cross_add.txt", cross_add)
LOGGER.info(f"[{self.filename}] Getting coordinate trace")
# coord_trace = self.trace_mol(ordered, cross_add)

coord_trace, simple_trace = self.simple_xyz_trace(ordered, cross_add, z, n=100)

Expand Down Expand Up @@ -2615,82 +2620,15 @@ def remove_common_values(arr1, arr2, retain=[]):

return np.asarray(filtered_arr1)

def trace_mol(self, ordered_segment_coords, both_img):
"""There's a problem with the code in that when tracing a non-circular molecule, the index
of a 'new' molecule will start at a new index (fine) but the ordering to the last coord may
start at the wrong section - should be moved to the end?
New order (?):
choose 0th index of coord section
-----
-----
check if there's an endpoint and order from there first
add coords to trace
remove coords from remaining image
get coords at end of trace
get index of area at end of reminaing
get next smallest index
----- repeat until terminated
----- repeat until all segments covered
Parameters
----------
ordered_segment_coords : _type_
_description_
both_img : _type_
_description_
Returns
-------
_type_
_description_
"""
mol_num = 0
mol_coords = []
remaining = both_img.copy().astype(np.int32)

remaining2 = remaining.copy()

binary_remaining = remaining.copy()
binary_remaining[binary_remaining != 0] = 1
endpoints = np.unique(remaining[convolve_skelly(binary_remaining) == 2]) # uniq incase of whole mol

while remaining.max() != 0:
# select endpoint to start if there is one
endpoints = [i for i in endpoints if i in np.unique(remaining)] # remove if removed from remaining
if endpoints:
coord_idx = endpoints[0] - 1
else: # if no endpoints, just a loop
coord_idx = np.unique(remaining)[1] - 1 # avoid choosing 0
coord_trace = np.empty((0, 2)).astype(np.int32)
mol_num += 1
while coord_idx > -1: # either cycled through all or hits terminus -> all will be just background
remaining[remaining == coord_idx + 1] = 0
trace_segment = self.get_trace_segment(remaining, ordered_segment_coords, coord_idx)
if len(coord_trace) > 0: # can only order when there's a reference point / segment
trace_segment = self.remove_duplicates(
trace_segment, prev_segment
) # remove overlaps in trace (may be more efficient to do it on the prev segment)
trace_segment = self.order_from_end(coord_trace[-1], trace_segment)
prev_segment = trace_segment.copy()
coord_trace = np.append(coord_trace, trace_segment.astype(np.int32), axis=0)
x, y = coord_trace[-1]
coord_idx = remaining[x - 1 : x + 2, y - 1 : y + 2].max() - 1 # should only be one value
mol_coords.append(coord_trace)

print(f"Mols in trace: {len(mol_coords)}")

return mol_coords

def simple_xyz_trace(self, ordered_segment_coords, both_img, zs, n=100):
"""Obtains a trace and simplified trace of the molecule by following connected segments."""
np.save(OUTPUT_DIR / "both.txt", both_img)
mol_coords = []
simple_coords = []
remaining = both_img.copy().astype(np.int32)
binary_remaining = remaining.copy()
binary_remaining[binary_remaining != 0] = 1
endpoints = np.unique(remaining[convolve_skelly(binary_remaining) == 2]) # uniq incase of whole mol
endpoints = np.unique(remaining[convolve_skelly(remaining) == 2]) # uniq incase of whole mol
n_points_p_seg = (n - 2 * remaining.max()) // remaining.max()

while remaining.max() != 0:
# select endpoint to start if there is one
endpoints = [i for i in endpoints if i in np.unique(remaining)] # remove if removed from remaining
Expand Down

0 comments on commit 850f668

Please sign in to comment.