Skip to content

Commit

Permalink
Redo the function get_node_ind.
Browse files Browse the repository at this point in the history
  • Loading branch information
Lin Wang committed Aug 4, 2023
1 parent 95519da commit b6f0867
Showing 1 changed file with 24 additions and 6 deletions.
30 changes: 24 additions & 6 deletions sleap_roots/angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,36 @@ def get_node_ind(pts: np.ndarray, proximal: bool = True) -> np.ndarray:
return np.nan

if proximal:
# For proximal, we want the first non-NaN node, so we reverse the mask and use
# argmax
node_ind = (~nan_mask[:, ::-1]).argmax(axis=1)
node_ind = pts.shape[1] - node_ind - 1 # adjust indices because of reversal
# For proximal, we want the first non-NaN node in the first half root
# get the first half nan mask (exclude the base node)
node_proximal = nan_mask[:, 1 : int((nan_mask.shape[1] + 1) / 2)]
# get the nearest non-Nan node index
node_ind = np.argmax(~node_proximal, axis=-1)
# if there is no non-Nan node, set value of 99
node_ind[node_proximal.all(axis=1)] = 99
node_ind = node_ind + 1 # adjust indices by adding one (base node)
else:
# For distal, we can directly use argmax
node_ind = (~nan_mask).argmax(axis=1)
# For distal, we want the last non-NaN node in the last half root
# get the last half nan mask
node_distal = nan_mask[:, int(nan_mask.shape[1] / 2) :]
# get the farest non-Nan node
node_ind = (node_distal[:, ::-1] == False).argmax(axis=1)
node_ind[node_distal.all(axis=1)] = -95 # set value if no non-Nan node
node_ind = pts.shape[1] - node_ind - 1 # adjust indices by reversing

# reset indices of 0 (base node) if no non-Nan node
node_ind[node_ind == 100] = 0

# If pts was originally 2D, return a scalar instead of a single-element array
if pts.shape[0] == 1:
return node_ind[0]

# If only one root, return a scalar instead of a single-element array
if node_ind.shape[0] == 1:
return node_ind[0]

return node_ind


def get_root_angle(
pts: np.ndarray, node_ind: np.ndarray, proximal: bool = True, base_ind=0
Expand Down

0 comments on commit b6f0867

Please sign in to comment.