Skip to content

Commit

Permalink
feat: new script to split segm mask along line mother-bud
Browse files Browse the repository at this point in the history
  • Loading branch information
ElpadoCan committed Nov 21, 2024
1 parent 731f3e5 commit 0015c4a
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 40 deletions.
18 changes: 4 additions & 14 deletions cellacdc/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,19 @@
import numpy as np
import pandas as pd

from . import printl
from . import printl, core

def split_segm_masks_mother_bud_line(lab, obj, obj_bud, interc_perp, slope_perp):
def split_segm_masks_mother_bud_line(lab, obj, obj_bud, ref_p1, ref_p2):
import matplotlib.pyplot as plt

lab = np.zeros_like(lab)
lab[obj.slice][obj.image] = obj.label
lab[obj_bud.slice][obj_bud.image] = obj_bud.label

y0 = 0
x0 = (y0 - interc_perp)/slope_perp

x1 = lab.shape[1]
y1 = slope_perp*x1 + interc_perp

x2 = 0
y2 = interc_perp

y3 = lab.shape[0]
x3 = (y3 - interc_perp)/slope_perp
(x_ref_0, y_ref_0), (x_ref1, y_ref1) = ref_p1, ref_p2

plt.imshow(lab)
plt.plot([x0, x1, x2, x3], [y0, y1, y2, y3], 'r')
plt.plot([x_ref_0, x_ref1], [y_ref_0, y_ref1], 'r')
plt.show()

import pdb; pdb.set_trace()
Expand Down
109 changes: 87 additions & 22 deletions cellacdc/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2616,21 +2616,17 @@ def split_segm_masks_mother_bud_line(
try:
ccs = acdc_df.at[(frame_i, obj.label), 'cell_cycle_stage']
except Exception as err:
pbar.update()
continue

if ccs != 'S':
pbar.update()
continue

try:
relationship = acdc_df.at[(frame_i, obj.label), 'relationship']
except Exception as err:
pbar.update()
continue

if relationship == 'bud':
pbar.update()
continue

bud_ID = int(acdc_df.at[(frame_i, obj.label), 'relative_ID'])
Expand All @@ -2644,36 +2640,103 @@ def split_segm_masks_mother_bud_line(
if slope_mb != 0:
slope_perp = -1/slope_mb
interc_perp = yc_m - xc_m*slope_perp
else:
slope_perp = np.inf
interc_perp = np.nan

ref_p1, ref_p2 = get_split_line_ref_points_img(
lab, slope_perp, interc_perp, xc_m, yc_m
)

if debug:
from cellacdc import _debug
_debug.split_segm_masks_mother_bud_line(
lab, obj, obj_bud, interc_perp, slope_perp
lab, obj, obj_bud, ref_p1, ref_p2
)





for z, lab_split in enumerate(segm_data_to_split[frame_i]):
lab_split_yy, lab_split_xx = np.nonzero(lab_split==obj.label)
if len(lab_split_yy) == 0:
continue

query_points = np.column_stack((lab_split_xx, lab_split_yy))
close_to_bud_mask = classify_points_plane_split_by_line(
ref_p1, ref_p2, query_points, (xc_b, yc_b)
)

split_close_yy = lab_split_yy[close_to_bud_mask]
split_close_xx = lab_split_xx[close_to_bud_mask]

split_segm_close[frame_i, z, split_close_yy, split_close_xx] = (
obj.label
)

split_away_yy = lab_split_yy[~close_to_bud_mask]
split_away_xx = lab_split_xx[~close_to_bud_mask]

split_segm_away[frame_i, z, split_away_yy, split_away_xx] = (
obj.label
)

pbar.update()
pbar.close()
pbar.close()

return split_segm_close, split_segm_away

def classify_points_plane_split_by_line(
p1, p2, query_points: np.ndarray, relative_to_p
):
"""Classify points on plane crossed by a line connecting p1 and p2 relative
to `relative_to_p` point
Parameters
----------
p1 : (x, y) of floats
First point of the line
p2 : (x, y) of floats
Second point
query_points : (N, 2) np.ndarray
(x, y) coordinates of the points to classify
References
----------
https://stackoverflow.com/questions/45766534/finding-cross-product-to-find-points-above-below-a-line-in-matplotlib
"""
relative_p_arr = np.array([relative_to_p])
a = np.array(p1)
b = np.array(p2)

class_relative_p = (np.cross(relative_p_arr-a, b-a) <= 0).astype(int)[0]
class_query_points = (np.cross(query_points-a, b-a) <= 0).astype(int)
query_points_mask = class_query_points == class_relative_p

return query_points_mask


def get_split_line_ref_points_img(img, slope_perp, interc_perp):
if slope_perp == np.inf:
...
def get_split_line_ref_points_img(img, slope, interc, xc, yc):
Y, X = img.shape
if slope == np.inf:
x_ref_0 = xc
y_ref_0 = 0
x_ref1 = xc
y_ref1 = Y
elif slope == 0:
x_ref_0 = 0
y_ref_0 = yc
x_ref1 = X
y_ref1 = yc
else:
Y, X = lab.shape
y0 = 0
x0 = y0 - interc_perp/slope_perp
x0 = y0 - interc/slope

x1 = X
y1 = slope_perp*x1 + interc_perp
y1 = slope*x1 + interc

x2 = 0
y2 = interc_perp
y2 = interc

y3 = Y
x3 = y3 - interc_perp/slope_perp
x3 = (y3 - interc)/slope

if x0 < X:
x_ref_0 = x0
Expand All @@ -2682,10 +2745,12 @@ def get_split_line_ref_points_img(img, slope_perp, interc_perp):
x_ref_0 = x1
y_ref_0 = y1

if x1 > 0:
x_ref1 = x1
y_ref1 = y1
if x3 > 0:
x_ref1 = x3
y_ref1 = y3
else:
x_ref1 = x2
y_ref1 = 0
y_ref1 = y2

return (x_ref_0, y_ref_0), (x_ref1, y_ref1)

32 changes: 28 additions & 4 deletions cellacdc/scripts/split_segm_mask_yeast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@

from tqdm import tqdm

import numpy as np

import qtpy.compat

from cellacdc import printl, myutils, apps, load, core
from cellacdc._run import _setup_app
from cellacdc.utils.base import NewThreadMultipleExpBaseUtil

DEBUG = True
DEBUG = False

def ask_select_folder():
selected_path = qtpy.compat.getexistingdirectory(
Expand Down Expand Up @@ -91,13 +93,35 @@ def run():
images_path, end_name_acdc_df_file=acdc_df_endname
)
for segm_endname in list_segm_endnames_to_split:
segm_data_to_split = load.load_segm_file(
images_path, end_name_segm_file=segm_endname
segm_data_to_split, segm_data_to_split_fp = load.load_segm_file(
images_path, end_name_segm_file=segm_endname,
return_path=True
)
core.split_segm_masks_mother_bud_line(
out = core.split_segm_masks_mother_bud_line(
cells_segm_data, segm_data_to_split, acdc_df,
debug=DEBUG
)
split_segm_close, split_segm_away = out

segm_data_to_split_fn = os.path.basename(segm_data_to_split_fp)

split_close_filename = segm_data_to_split_fn.replace(
segm_endname, f'{segm_endname}_split_close.npz'
).replace('.npz.npz', '.npz')
split_close_filepath = os.path.join(
images_path, split_close_filename
)

np.savez_compressed(split_close_filepath, split_segm_close)


split_away_filename = segm_data_to_split_fn.replace(
segm_endname, f'{segm_endname}_split_away.npz'
).replace('.npz.npz', '.npz')
split_away_filepath = os.path.join(
images_path, split_away_filename
)
np.savez_compressed(split_away_filepath, split_segm_away)
pbar.update()

pbar.close()
Expand Down

0 comments on commit 0015c4a

Please sign in to comment.