diff --git a/cellacdc/_debug.py b/cellacdc/_debug.py index 75fbf920..a0f3e729 100644 --- a/cellacdc/_debug.py +++ b/cellacdc/_debug.py @@ -3,29 +3,19 @@ import numpy as np import pandas as pd -from . import printl +from . import printl, core -def split_segm_masks_mother_bud_line(lab, obj, obj_bud, interc_perp, slope_perp): +def split_segm_masks_mother_bud_line(lab, obj, obj_bud, ref_p1, ref_p2): import matplotlib.pyplot as plt lab = np.zeros_like(lab) lab[obj.slice][obj.image] = obj.label lab[obj_bud.slice][obj_bud.image] = obj_bud.label - y0 = 0 - x0 = (y0 - interc_perp)/slope_perp - - x1 = lab.shape[1] - y1 = slope_perp*x1 + interc_perp - - x2 = 0 - y2 = interc_perp - - y3 = lab.shape[0] - x3 = (y3 - interc_perp)/slope_perp + (x_ref_0, y_ref_0), (x_ref1, y_ref1) = ref_p1, ref_p2 plt.imshow(lab) - plt.plot([x0, x1, x2, x3], [y0, y1, y2, y3], 'r') + plt.plot([x_ref_0, x_ref1], [y_ref_0, y_ref1], 'r') plt.show() import pdb; pdb.set_trace() diff --git a/cellacdc/core.py b/cellacdc/core.py index 4c1b2c61..68690475 100755 --- a/cellacdc/core.py +++ b/cellacdc/core.py @@ -2616,21 +2616,17 @@ def split_segm_masks_mother_bud_line( try: ccs = acdc_df.at[(frame_i, obj.label), 'cell_cycle_stage'] except Exception as err: - pbar.update() continue if ccs != 'S': - pbar.update() continue try: relationship = acdc_df.at[(frame_i, obj.label), 'relationship'] except Exception as err: - pbar.update() continue if relationship == 'bud': - pbar.update() continue bud_ID = int(acdc_df.at[(frame_i, obj.label), 'relative_ID']) @@ -2644,36 +2640,103 @@ def split_segm_masks_mother_bud_line( if slope_mb != 0: slope_perp = -1/slope_mb interc_perp = yc_m - xc_m*slope_perp + else: + slope_perp = np.inf + interc_perp = np.nan + + ref_p1, ref_p2 = get_split_line_ref_points_img( + lab, slope_perp, interc_perp, xc_m, yc_m + ) if debug: from cellacdc import _debug _debug.split_segm_masks_mother_bud_line( - lab, obj, obj_bud, interc_perp, slope_perp + lab, obj, obj_bud, ref_p1, ref_p2 ) - - - - + for z, lab_split in enumerate(segm_data_to_split[frame_i]): + lab_split_yy, lab_split_xx = np.nonzero(lab_split==obj.label) + if len(lab_split_yy) == 0: + continue + + query_points = np.column_stack((lab_split_xx, lab_split_yy)) + close_to_bud_mask = classify_points_plane_split_by_line( + ref_p1, ref_p2, query_points, (xc_b, yc_b) + ) + + split_close_yy = lab_split_yy[close_to_bud_mask] + split_close_xx = lab_split_xx[close_to_bud_mask] + + split_segm_close[frame_i, z, split_close_yy, split_close_xx] = ( + obj.label + ) + + split_away_yy = lab_split_yy[~close_to_bud_mask] + split_away_xx = lab_split_xx[~close_to_bud_mask] + + split_segm_away[frame_i, z, split_away_yy, split_away_xx] = ( + obj.label + ) + pbar.update() - pbar.close() + pbar.close() + + return split_segm_close, split_segm_away + +def classify_points_plane_split_by_line( + p1, p2, query_points: np.ndarray, relative_to_p + ): + """Classify points on plane crossed by a line connecting p1 and p2 relative + to `relative_to_p` point + + Parameters + ---------- + p1 : (x, y) of floats + First point of the line + p2 : (x, y) of floats + Second point + query_points : (N, 2) np.ndarray + (x, y) coordinates of the points to classify + + References + ---------- + https://stackoverflow.com/questions/45766534/finding-cross-product-to-find-points-above-below-a-line-in-matplotlib + """ + relative_p_arr = np.array([relative_to_p]) + a = np.array(p1) + b = np.array(p2) + + class_relative_p = (np.cross(relative_p_arr-a, b-a) <= 0).astype(int)[0] + class_query_points = (np.cross(query_points-a, b-a) <= 0).astype(int) + query_points_mask = class_query_points == class_relative_p + + return query_points_mask + -def get_split_line_ref_points_img(img, slope_perp, interc_perp): - if slope_perp == np.inf: - ... +def get_split_line_ref_points_img(img, slope, interc, xc, yc): + Y, X = img.shape + if slope == np.inf: + x_ref_0 = xc + y_ref_0 = 0 + x_ref1 = xc + y_ref1 = Y + elif slope == 0: + x_ref_0 = 0 + y_ref_0 = yc + x_ref1 = X + y_ref1 = yc else: - Y, X = lab.shape y0 = 0 - x0 = y0 - interc_perp/slope_perp + x0 = y0 - interc/slope x1 = X - y1 = slope_perp*x1 + interc_perp + y1 = slope*x1 + interc x2 = 0 - y2 = interc_perp + y2 = interc y3 = Y - x3 = y3 - interc_perp/slope_perp + x3 = (y3 - interc)/slope if x0 < X: x_ref_0 = x0 @@ -2682,10 +2745,12 @@ def get_split_line_ref_points_img(img, slope_perp, interc_perp): x_ref_0 = x1 y_ref_0 = y1 - if x1 > 0: - x_ref1 = x1 - y_ref1 = y1 + if x3 > 0: + x_ref1 = x3 + y_ref1 = y3 else: x_ref1 = x2 - y_ref1 = 0 + y_ref1 = y2 + + return (x_ref_0, y_ref_0), (x_ref1, y_ref1) \ No newline at end of file diff --git a/cellacdc/scripts/split_segm_mask_yeast.py b/cellacdc/scripts/split_segm_mask_yeast.py index b774ff0e..3eab10ee 100644 --- a/cellacdc/scripts/split_segm_mask_yeast.py +++ b/cellacdc/scripts/split_segm_mask_yeast.py @@ -2,13 +2,15 @@ from tqdm import tqdm +import numpy as np + import qtpy.compat from cellacdc import printl, myutils, apps, load, core from cellacdc._run import _setup_app from cellacdc.utils.base import NewThreadMultipleExpBaseUtil -DEBUG = True +DEBUG = False def ask_select_folder(): selected_path = qtpy.compat.getexistingdirectory( @@ -91,13 +93,35 @@ def run(): images_path, end_name_acdc_df_file=acdc_df_endname ) for segm_endname in list_segm_endnames_to_split: - segm_data_to_split = load.load_segm_file( - images_path, end_name_segm_file=segm_endname + segm_data_to_split, segm_data_to_split_fp = load.load_segm_file( + images_path, end_name_segm_file=segm_endname, + return_path=True ) - core.split_segm_masks_mother_bud_line( + out = core.split_segm_masks_mother_bud_line( cells_segm_data, segm_data_to_split, acdc_df, debug=DEBUG ) + split_segm_close, split_segm_away = out + + segm_data_to_split_fn = os.path.basename(segm_data_to_split_fp) + + split_close_filename = segm_data_to_split_fn.replace( + segm_endname, f'{segm_endname}_split_close.npz' + ).replace('.npz.npz', '.npz') + split_close_filepath = os.path.join( + images_path, split_close_filename + ) + + np.savez_compressed(split_close_filepath, split_segm_close) + + + split_away_filename = segm_data_to_split_fn.replace( + segm_endname, f'{segm_endname}_split_away.npz' + ).replace('.npz.npz', '.npz') + split_away_filepath = os.path.join( + images_path, split_away_filename + ) + np.savez_compressed(split_away_filepath, split_segm_away) pbar.update() pbar.close()