diff --git a/.gitignore b/.gitignore index b11bf60c99..f309418b3e 100644 --- a/.gitignore +++ b/.gitignore @@ -57,3 +57,8 @@ topostats/_version.py # default output directory, often common from testing output/ + +# Include all files in tests and all subdirectories except processed and __pycache__ +!tests/** +tests/resources/processed/ +__pycache__/ diff --git a/pyproject.toml b/pyproject.toml index 8fa147b54a..6ae49a0c97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,14 +36,15 @@ keywords = [ ] requires-python = ">=3.9" dependencies = [ + "AFMReader", "h5py", "igor2", "keras", "matplotlib", "numpy", "pandas", - "pySPM", "pyfiglet", + "pySPM", "pyyaml", "ruamel.yaml", "schema", @@ -53,7 +54,7 @@ dependencies = [ "skan", "snoop", "tifffile", - "AFMReader", + "topoly", "tqdm", "tensorflow", ] diff --git a/tests/resources/catenane_node_0_avg_image.npy b/tests/resources/catenane_node_0_avg_image.npy new file mode 100644 index 0000000000..a0cfc1e80d Binary files /dev/null and b/tests/resources/catenane_node_0_avg_image.npy differ diff --git a/tests/resources/catenane_node_0_branch_image.npy b/tests/resources/catenane_node_0_branch_image.npy new file mode 100644 index 0000000000..4c5951098f Binary files /dev/null and b/tests/resources/catenane_node_0_branch_image.npy differ diff --git a/tests/resources/catenane_node_0_reduced_node_area.npy b/tests/resources/catenane_node_0_reduced_node_area.npy new file mode 100644 index 0000000000..f1b0db3c80 Binary files /dev/null and b/tests/resources/catenane_node_0_reduced_node_area.npy differ diff --git a/tests/resources/example_catenanes.npy b/tests/resources/example_catenanes.npy new file mode 100644 index 0000000000..fe48e84ccc Binary files /dev/null and b/tests/resources/example_catenanes.npy differ diff --git a/tests/resources/example_catenanes_labelled_grain_mask_thresholded.npy b/tests/resources/example_catenanes_labelled_grain_mask_thresholded.npy new file mode 100644 index 0000000000..7570f4ae8d Binary files 
/dev/null and b/tests/resources/example_catenanes_labelled_grain_mask_thresholded.npy differ diff --git a/tests/resources/example_rep_int.npy b/tests/resources/example_rep_int.npy new file mode 100644 index 0000000000..7df4fd5f76 Binary files /dev/null and b/tests/resources/example_rep_int.npy differ diff --git a/tests/resources/example_rep_int_labelled_grain_mask_thresholded.npy b/tests/resources/example_rep_int_labelled_grain_mask_thresholded.npy new file mode 100644 index 0000000000..182a91e46a Binary files /dev/null and b/tests/resources/example_rep_int_labelled_grain_mask_thresholded.npy differ diff --git a/tests/resources/catenane_all_connected_nodes.npy b/tests/resources/nodestats_analyse_nodes_catenane_all_connected_nodes.npy similarity index 100% rename from tests/resources/catenane_all_connected_nodes.npy rename to tests/resources/nodestats_analyse_nodes_catenane_all_connected_nodes.npy diff --git a/tests/resources/catenane_image_dict.pkl b/tests/resources/nodestats_analyse_nodes_catenane_image_dict.pkl similarity index 67% rename from tests/resources/catenane_image_dict.pkl rename to tests/resources/nodestats_analyse_nodes_catenane_image_dict.pkl index 2069a2b6de..5f60e204ea 100644 Binary files a/tests/resources/catenane_image_dict.pkl and b/tests/resources/nodestats_analyse_nodes_catenane_image_dict.pkl differ diff --git a/tests/resources/catenane_node_dict.pkl b/tests/resources/nodestats_analyse_nodes_catenane_node_dict.pkl similarity index 85% rename from tests/resources/catenane_node_dict.pkl rename to tests/resources/nodestats_analyse_nodes_catenane_node_dict.pkl index 4c3b5d433c..33017d01f8 100644 Binary files a/tests/resources/catenane_node_dict.pkl and b/tests/resources/nodestats_analyse_nodes_catenane_node_dict.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_all_images.pkl b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_all_images.pkl new file mode 100644 index 
0000000000..986125e09e Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_all_images.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_crop_data.pkl b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_crop_data.pkl new file mode 100644 index 0000000000..54aa6f9c60 Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_crop_data.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_grainstats.csv b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_grainstats.csv new file mode 100644 index 0000000000..3f5764790d --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,grain_endpoints,grain_junctions,total_branch_lengths +0,test_image,0,0,14,575.5249825836854 +1,test_image,1,0,12,574.7857998943139 diff --git a/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_stats.csv b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_stats.csv new file mode 100644 index 0000000000..4e0198f6f8 --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/catenanes_disordered_tracing_stats.csv @@ -0,0 +1,25 @@ +,image,grain_number,branch_distance,branch_type,connected_segments,mean_pixel_value,stdev_pixel_value,min_value,median_value,middle_value +0,test_image,0,29.857677897827887,2,"[1, 2, 6]",2.6683043052890536,0.3576668425125195,1.993486285613394,2.6464969869174926,2.6743964509318303 +1,test_image,0,28.644860587199464,2,"[0, 2, 6]",2.728953130422273,0.21545195325432523,2.5506513793095915,2.6838064728094184,2.6579579145357006 +2,test_image,0,1.8682724368761408,2,"[0, 1, 3, 4]",4.067541277649083,0.16905835685823303,3.7890382980796193,4.138536286707056,4.13958834095104 +3,test_image,0,146.60562184380709,2,"[2, 4, 
9]",2.6693913722865377,0.2947637179915079,1.2200007753773205,2.7080175948240655,2.9115862857598005 +4,test_image,0,26.170179495009112,2,"[2, 3, 9]",2.834959559756504,0.3436409841063358,2.535374801513602,2.7607877845412387,2.7651337225310604 +5,test_image,0,224.58506419260596,2,"[6, 7, 10, 11]",2.700995491724554,0.2653435468736685,1.787397162651758,2.7330374920770115,2.2573671026451536 +6,test_image,0,6.9504086553142095,2,"[0, 1, 5, 7]",3.626373814193468,0.6406710041895474,2.588846104614013,3.84195778773634,3.84195778773634 +7,test_image,0,31.203269242513674,2,"[5, 6, 8, 11]",2.843560036090906,0.5572924823368127,2.1307434183593315,2.705164120987648,2.75295411300335 +8,test_image,0,31.286996805637532,2,"[7, 9, 10, 11]",2.7443198761048193,0.3822328907364185,2.4149214282007074,2.653495727717319,2.4706619619066386 +9,test_image,0,5.570136218438069,2,"[3, 4, 8, 10]",2.7308714736136963,0.049010037214155706,2.6525934702850984,2.7378802570467666,2.749457125278597 +10,test_image,0,38.18835899001824,2,"[5, 8, 9, 11]",2.7864423529784395,0.15142352544214144,2.445066671627662,2.7710339081422832,2.6967078592376783 +11,test_image,0,4.59413621843807,2,"[5, 7, 8, 10]",4.126603901043527,0.37133617585304246,3.351114087306773,4.238726491399509,4.238726491399509 +0,test_image,1,37.70035899001824,2,"[1, 2, 3, 6]",2.7732793700561382,0.13491385165238218,2.445066671627662,2.7690167332032654,2.7141994425994813 +1,test_image,1,223.89492797416787,2,"[0, 2, 4, 9]",2.69655183205219,0.26110756906553556,1.8203021919659572,2.7314109425842963,2.2743843560760055 +2,test_image,1,5.08213621843807,2,"[0, 1, 3, 4]",4.029037066510506,0.4638075153095662,3.0770475331177405,4.18502108746829,4.18502108746829 +3,test_image,1,31.001133024075603,2,"[0, 2, 4, 6]",2.746190739871742,0.4020468170580353,2.443926567140926,2.6515637708741715,2.443926567140926 +4,test_image,1,31.405405460951744,2,"[1, 2, 3, 9]",2.8422583900088516,0.5622599874985189,2.1307434183593315,2.694237316611716,2.75295411300335 
+5,test_image,1,147.49789428068328,2,"[6, 7, 11]",2.6728103603722104,0.2940846131940484,1.2200007753773205,2.712208521357603,2.9467266502403127 +6,test_image,1,5.77227243687614,2,"[0, 3, 5, 7]",2.7225654392347494,0.051361792194696554,2.6525934702850984,2.7127873939436142,2.7180535973290754 +7,test_image,1,25.682179495009112,2,"[5, 6, 11]",2.85129816853438,0.36309882085275597,2.535374801513602,2.7617733887534093,2.770465264733052 +8,test_image,1,27.870996805637535,2,"[9, 10, 11]",2.7169147220384864,0.1798385572282301,2.5506513793095915,2.6863743944825202,2.6579579145357006 +9,test_image,1,6.9504086553142095,2,"[1, 4, 8, 10]",3.679401125994582,0.5950886464826225,2.6485713386465375,3.84195778773634,3.84195778773634 +10,test_image,1,29.369677897827888,2,"[8, 9, 11]",2.640305654101202,0.3106025836163093,1.993486285613394,2.648283332587627,2.682776277975024 +11,test_image,1,2.558408655314211,2,"[5, 7, 8, 10]",3.974846970762266,0.23923528279875325,3.6040697432150033,4.075122442799481,4.075122442799481 diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_all_images.pkl b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_all_images.pkl new file mode 100644 index 0000000000..22bd808c7e Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_all_images.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_crop_data.pkl b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_crop_data.pkl new file mode 100644 index 0000000000..0cfa26e107 Binary files /dev/null and b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_crop_data.pkl differ diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_grainstats.csv b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_grainstats.csv new file mode 100644 index 0000000000..f7c3c2e9af --- /dev/null +++ 
b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_grainstats.csv @@ -0,0 +1,2 @@ +,image,grain_number,grain_endpoints,grain_junctions,total_branch_lengths +0,test_image,0,0,13,968.5225788725928 diff --git a/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_stats.csv b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_stats.csv new file mode 100644 index 0000000000..e503c214bb --- /dev/null +++ b/tests/resources/tracing/disordered_tracing/rep_int_disordered_tracing_stats.csv @@ -0,0 +1,13 @@ +,image,grain_number,branch_distance,branch_type,connected_segments,mean_pixel_value,stdev_pixel_value,min_value,median_value,middle_value +0,test_image,0,172.69207377569276,2,"[1, 2, 3, 8]",2.147391874570661,0.18580255309380492,1.4117732114082249,2.147881683634531,2.193944711921703 +1,test_image,0,338.03541679389866,2,"[0, 2, 7, 11]",2.1261598928571943,0.1841696602430932,1.127883594233125,2.1329128300921694,2.236342991231205 +2,test_image,0,0.6901362184380704,2,"[0, 1, 3, 4]",3.5129004230159246,0.2441665496774352,3.2687338733384856,3.5129004230159246,3.5129004230159246 +3,test_image,0,75.41508335877963,2,"[0, 2, 4, 8]",2.1781050277619003,0.22952699769811186,1.4117732114082249,2.178401374885392,2.2734312397952836 +4,test_image,0,31.893405460951744,2,"[2, 3, 5, 6]",2.1378544765518392,0.338975206263122,1.7555674668849865,2.094437231758162,2.2880045741720902 +5,test_image,0,51.51744873752279,2,"[4, 6, 7]",2.233921716584725,0.23610837345626215,1.7265179765777217,2.206985179983061,2.0220519901178635 +6,test_image,0,1.1781362184380704,2,"[4, 5, 7]",3.4335547526985297,0.29158462263860985,3.0268131230204918,3.578160584306969,3.6956905507681284 +7,test_image,0,153.8275289019402,2,"[1, 5, 6, 11]",2.1708118960130536,0.23399433389250326,1.32021443851052,2.1773699140297205,2.1908233129216366 +8,test_image,0,58.753721174398926,2,"[0, 3, 10, 
11]",2.0020290183250613,0.3370066648426065,0.8859157914742134,2.08256912431822,2.247454035474802 +9,test_image,0,38.962222771580166,3,[10],2.2735993718579453,0.17804635219989257,2.0015543670262534,2.2595377460435833,2.2032782158757636 +10,test_image,0,0.488,2,"[8, 9, 11]",2.875087133879444,0.009881348191510873,2.865205785688025,2.875087133879444,2.875087133879444 +11,test_image,0,45.069405460951735,2,"[1, 7, 8, 10]",2.1854926308595277,0.22637773949873832,1.289281969607514,2.227074000531621,2.227074000531621 diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_all_images.pkl b/tests/resources/tracing/nodestats/catenanes_nodestats_all_images.pkl new file mode 100644 index 0000000000..a9b06ffe05 Binary files /dev/null and b/tests/resources/tracing/nodestats/catenanes_nodestats_all_images.pkl differ diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_branch_images.pkl b/tests/resources/tracing/nodestats/catenanes_nodestats_branch_images.pkl new file mode 100644 index 0000000000..676f09ceb3 Binary files /dev/null and b/tests/resources/tracing/nodestats/catenanes_nodestats_branch_images.pkl differ diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_data.pkl b/tests/resources/tracing/nodestats/catenanes_nodestats_data.pkl new file mode 100644 index 0000000000..2f78a8b297 Binary files /dev/null and b/tests/resources/tracing/nodestats/catenanes_nodestats_data.pkl differ diff --git a/tests/resources/tracing/nodestats/catenanes_nodestats_grainstats.csv b/tests/resources/tracing/nodestats/catenanes_nodestats_grainstats.csv new file mode 100644 index 0000000000..3e0c534882 --- /dev/null +++ b/tests/resources/tracing/nodestats/catenanes_nodestats_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,num_crossings,avg_crossing_confidence,min_crossing_confidence +grain_0,test_image,0,4,0.4013589828832889,0.2129989376767838 +grain_1,test_image,1,4,0.3441057054647598,0.17063184531586506 diff --git 
a/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_no_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_no_pair_odd_branches.pkl new file mode 100644 index 0000000000..6d5cbc25f3 Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_no_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_pair_odd_branches.pkl new file mode 100644 index 0000000000..6d5cbc25f3 Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_all_images_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_no_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_no_pair_odd_branches.pkl new file mode 100644 index 0000000000..a16811b5bf Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_no_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_pair_odd_branches.pkl new file mode 100644 index 0000000000..3d12688a43 Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_branch_images_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_data_no_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_data_no_pair_odd_branches.pkl new file mode 100644 index 0000000000..b85a1749aa Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_data_no_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_data_pair_odd_branches.pkl b/tests/resources/tracing/nodestats/rep_int_nodestats_data_pair_odd_branches.pkl new file mode 100644 index 
0000000000..d00659696d Binary files /dev/null and b/tests/resources/tracing/nodestats/rep_int_nodestats_data_pair_odd_branches.pkl differ diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_no_pair_odd_branches.csv b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_no_pair_odd_branches.csv new file mode 100644 index 0000000000..c08b2d241c --- /dev/null +++ b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_no_pair_odd_branches.csv @@ -0,0 +1,2 @@ +,image,grain_number,num_crossings,avg_crossing_confidence,min_crossing_confidence +grain_0,test_image,0,5,0.07082753253520613,0.01059564637975774 diff --git a/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_pair_odd_branches.csv b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_pair_odd_branches.csv new file mode 100644 index 0000000000..c08b2d241c --- /dev/null +++ b/tests/resources/tracing/nodestats/rep_int_nodestats_grainstats_pair_odd_branches.csv @@ -0,0 +1,2 @@ +,image,grain_number,num_crossings,avg_crossing_confidence,min_crossing_confidence +grain_0,test_image,0,5,0.07082753253520613,0.01059564637975774 diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_data.pkl b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_data.pkl new file mode 100644 index 0000000000..29949ab944 Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_data.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_full_images.pkl b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_full_images.pkl new file mode 100644 index 0000000000..a41553a5e1 Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_full_images.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_grainstats.csv 
b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_grainstats.csv new file mode 100644 index 0000000000..722a3a0b22 --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,num_mols,writhe_string +grain_0,catenane,0,2,++++ +grain_1,catenane,1,2,+-++ diff --git a/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_molstats.csv b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_molstats.csv new file mode 100644 index 0000000000..55ecf1c326 --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/catenanes_ordered_tracing_molstats.csv @@ -0,0 +1,5 @@ +,image,grain_number,molecule_number,circular,topology,topology_flip,processing +0,catenane,0,0,True,4^2_1,2^2_1,nodestats +1,catenane,0,1,True,4^2_1,2^2_1,nodestats +2,catenane,1,0,True,2^2_1,0_1U0_1,nodestats +3,catenane,1,1,True,2^2_1,0_1U0_1,nodestats diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_data.pkl b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_data.pkl new file mode 100644 index 0000000000..bcfd1062da Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_data.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_full_images.pkl b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_full_images.pkl new file mode 100644 index 0000000000..4e4b9c550b Binary files /dev/null and b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_full_images.pkl differ diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_grainstats.csv b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_grainstats.csv new file mode 100644 index 0000000000..d350d7923e --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_grainstats.csv @@ -0,0 +1,2 @@ +,image,grain_number,num_mols,writhe_string 
+grain_0,replication_intermediate,0,3,--- diff --git a/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_molstats.csv b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_molstats.csv new file mode 100644 index 0000000000..734d5cc3eb --- /dev/null +++ b/tests/resources/tracing/ordered_tracing/rep_int_ordered_tracing_molstats.csv @@ -0,0 +1,4 @@ +,image,grain_number,molecule_number,circular,topology,topology_flip,processing +0,replication_intermediate,0,0,False,linear,linear,nodestats +1,replication_intermediate,0,1,False,linear,linear,nodestats +2,replication_intermediate,0,2,False,linear,linear,nodestats diff --git a/tests/resources/tracing/splining/catenanes_splining_data.pkl b/tests/resources/tracing/splining/catenanes_splining_data.pkl new file mode 100644 index 0000000000..381d308cad Binary files /dev/null and b/tests/resources/tracing/splining/catenanes_splining_data.pkl differ diff --git a/tests/resources/tracing/splining/catenanes_splining_grainstats.csv b/tests/resources/tracing/splining/catenanes_splining_grainstats.csv new file mode 100644 index 0000000000..cf0f3cc930 --- /dev/null +++ b/tests/resources/tracing/splining/catenanes_splining_grainstats.csv @@ -0,0 +1,3 @@ +,image,grain_number,total_contour_length,average_end_to_end_distance +grain_0,catenane,0,1113.3149766322022,0.0 +grain_1,catenane,1,1113.4528311181873,0.0 diff --git a/tests/resources/tracing/splining/catenanes_splining_molstats.csv b/tests/resources/tracing/splining/catenanes_splining_molstats.csv new file mode 100644 index 0000000000..c44c187f95 --- /dev/null +++ b/tests/resources/tracing/splining/catenanes_splining_molstats.csv @@ -0,0 +1,5 @@ +,image,grain_number,molecule_number,contour_length,end_to_end_distance +0,catenane,0,0,846.6004494307274,0 +1,catenane,0,1,266.7145272014749,0 +2,catenane,1,0,846.5829590713623,0 +3,catenane,1,1,266.869872046825,0 diff --git a/tests/resources/tracing/splining/rep_int_splining_data.pkl 
b/tests/resources/tracing/splining/rep_int_splining_data.pkl new file mode 100644 index 0000000000..d5982ae8c7 Binary files /dev/null and b/tests/resources/tracing/splining/rep_int_splining_data.pkl differ diff --git a/tests/resources/tracing/splining/rep_int_splining_grainstats.csv b/tests/resources/tracing/splining/rep_int_splining_grainstats.csv new file mode 100644 index 0000000000..9e98c7ae0d --- /dev/null +++ b/tests/resources/tracing/splining/rep_int_splining_grainstats.csv @@ -0,0 +1,2 @@ +,image,grain_number,total_contour_length,average_end_to_end_distance +grain_0,replication_intermediate,0,1773.7493902268593,165.97964434071082 diff --git a/tests/resources/tracing/splining/rep_int_splining_molstats.csv b/tests/resources/tracing/splining/rep_int_splining_molstats.csv new file mode 100644 index 0000000000..3a400e53d7 --- /dev/null +++ b/tests/resources/tracing/splining/rep_int_splining_molstats.csv @@ -0,0 +1,4 @@ +,image,grain_number,molecule_number,contour_length,end_to_end_distance +0,replication_intermediate,0,0,748.2001882005334,167.88686666919483 +1,replication_intermediate,0,1,766.7436650901011,167.0029939851379 +2,replication_intermediate,0,2,258.80553693622494,163.04907236779974 diff --git a/tests/test_io.py b/tests/test_io.py index 9d26bc4898..8940cfcfb0 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -15,6 +15,7 @@ from topostats.io import ( LoadScans, convert_basename_to_relative_paths, + dict_almost_equal, dict_to_hdf5, find_files, get_date_time, @@ -59,48 +60,6 @@ # pylint: disable=too-many-lines -def dict_almost_equal(dict1, dict2, abs_tol=1e-9): - """Recursively check if two dictionaries are almost equal with a given absolute tolerance. - - Parameters - ---------- - dict1: dict - First dictionary to compare. - dict2: dict - Second dictionary to compare. - abs_tol: float - Absolute tolerance to check for equality. - - Returns - ------- - bool - True if the dictionaries are almost equal, False otherwise. 
- """ - if dict1.keys() != dict2.keys(): - return False - - LOGGER.info("Comparing dictionaries") - - for key in dict1: - LOGGER.info(f"Comparing key {key}") - if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): - if not dict_almost_equal(dict1[key], dict2[key], abs_tol=abs_tol): - return False - elif isinstance(dict1[key], np.ndarray) and isinstance(dict2[key], np.ndarray): - if not np.allclose(dict1[key], dict2[key], atol=abs_tol): - LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") - return False - elif isinstance(dict1[key], float) and isinstance(dict2[key], float): - if not np.isclose(dict1[key], dict2[key], atol=abs_tol): - LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") - return False - elif dict1[key] != dict2[key]: - LOGGER.info(f"Key {key} not equal: {dict1[key]} != {dict2[key]}") - return False - - return True - - def test_get_date_time() -> None: """Test the fetching of a formatted date and time string.""" assert datetime.strptime(get_date_time(), "%Y-%m-%d %H:%M:%S") @@ -281,6 +240,13 @@ def test_load_array() -> None: False, id="float not equal", ), + pytest.param( + {"a": np.nan}, + {"a": np.nan}, + 0.0001, + True, + id="nan equal", + ), ], ) def test_dict_almost_equal(dict1: dict, dict2: dict, tolerance: float, expected: bool) -> None: @@ -769,7 +735,7 @@ def test_dict_to_hdf5_all_together_group_path_non_standard(tmp_path: Path) -> No assert list(f.keys()) == list(expected.keys()) assert f["d"]["a"][()] == expected["d"]["a"] np.testing.assert_array_equal(f["d"]["b"][()], expected["d"]["b"]) - assert f["d"]["c"][()].decode("utf-8") == expected["d"]["c"] + assert f["d"]["c"][()].decode("utf-8") == expected["d"]["c"] # pylint: disable=no-member assert f["d"]["d"]["e"][()] == expected["d"]["d"]["e"] np.testing.assert_array_equal(f["d"]["d"]["f"][()], expected["d"]["d"]["f"]) assert f["d"]["d"]["g"][()].decode("utf-8") == expected["d"]["d"]["g"] diff --git 
a/tests/tracing/conftest.py b/tests/tracing/conftest.py index 0d54374152..514f2f8338 100644 --- a/tests/tracing/conftest.py +++ b/tests/tracing/conftest.py @@ -1,6 +1,5 @@ """Fixtures for the tracing tests.""" -import pickle from pathlib import Path import numpy as np @@ -42,7 +41,7 @@ @pytest.fixture() def test_dnatracing() -> dnaTrace: """Instantiate a dnaTrace object.""" - return dnaTrace(image=FULL_IMAGE, grain=GRAINS, filename="Test", pixel_to_nm_scaling=1.0) + return dnaTrace(image=FULL_IMAGE, mask=GRAINS, filename="Test", pixel_to_nm_scaling=1.0) @pytest.fixture() @@ -165,18 +164,6 @@ def catenane_image() -> npt.NDArray[np.number]: return np.load(RESOURCES / "catenane_image.npy") -@pytest.fixture() -def catenane_skeleton() -> npt.NDArray[np.bool_]: - """Skeleton of the catenane test image.""" - return np.load(RESOURCES / "catenane_skeleton.npy") - - -@pytest.fixture() -def catenane_smoothed_mask() -> npt.NDArray[np.bool_]: - """Smoothed mask of the catenane test image.""" - return np.load(RESOURCES / "catenane_smoothed_mask.npy") - - @pytest.fixture() def catenane_node_centre_mask() -> npt.NDArray[np.int32]: """ @@ -202,12 +189,13 @@ def catenane_connected_nodes() -> npt.NDArray[np.int32]: @pytest.fixture() def nodestats_catenane( catenane_image: npt.NDArray[np.number], - catenane_smoothed_mask: npt.NDArray[np.bool_], - catenane_skeleton: npt.NDArray[np.bool_], - catenane_node_centre_mask: npt.NDArray[np.int32], - catenane_connected_nodes: npt.NDArray[np.int32], ) -> nodeStats: """Fixture for the nodeStats object for a catenated molecule, to be used in analyse_nodes.""" + catenane_smoothed_mask: npt.NDArray[np.bool_] = np.load(RESOURCES / "catenane_smoothed_mask.npy") + catenane_skeleton: npt.NDArray[np.bool_] = np.load(RESOURCES / "catenane_skeleton.npy") + catenane_node_centre_mask = np.load(RESOURCES / "catenane_node_centre_mask.npy") + catenane_connected_nodes = np.load(RESOURCES / "catenane_connected_nodes.npy") + # Create a nodestats object nodestats = 
nodeStats( filename="test_catenane", @@ -215,11 +203,12 @@ def nodestats_catenane( mask=catenane_smoothed_mask, smoothed_mask=catenane_smoothed_mask, skeleton=catenane_skeleton, - px_2_nm=np.float64(0.18124609375), + pixel_to_nm_scaling=np.float64(0.18124609375), n_grain=1, node_joining_length=7, node_extend_dist=14.0, branch_pairing_length=20.0, + pair_odd_branches=True, ) nodestats.node_centre_mask = catenane_node_centre_mask @@ -227,25 +216,3 @@ def nodestats_catenane( nodestats.skeleton = catenane_skeleton return nodestats - - -# pylint: disable=unspecified-encoding -@pytest.fixture() -def nodestats_catenane_node_dict() -> dict: - """Node dictionary for the catenane test image.""" - with Path.open(RESOURCES / "catenane_node_dict.pkl", "rb") as file: - return pickle.load(file) # noqa: S301 - Pickles unsafe but we don't care - - -# pylint: disable=unspecified-encoding -@pytest.fixture() -def nodestats_catenane_image_dict() -> dict: - """Image dictionary for the catenane test image.""" - with Path.open(RESOURCES / "catenane_image_dict.pkl", "rb") as file: - return pickle.load(file) # noqa: S301 - Pickles unsafe but we don't care - - -@pytest.fixture() -def nodestats_catenane_all_connected_nodes() -> npt.NDArray[np.int32]: - """All connected nodes for the catenane test image.""" - return np.load(RESOURCES / "catenane_all_connected_nodes.npy") diff --git a/tests/tracing/test_disordered_tracing.py b/tests/tracing/test_disordered_tracing.py new file mode 100644 index 0000000000..e402e104fa --- /dev/null +++ b/tests/tracing/test_disordered_tracing.py @@ -0,0 +1,1296 @@ +# Disable ruff 301 - pickle loading is unsafe but we don't care for tests +# ruff: noqa: S301 +"""Test the disordered tracing module.""" + +import pickle as pkl +from pathlib import Path + +import numpy as np +import numpy.typing as npt +import pandas as pd +import pytest + +from topostats.io import dict_almost_equal # pylint: disable=no-name-in-module import-error +from 
topostats.tracing.disordered_tracing import disordered_trace_grain, trace_image_disordered + +# pylint: disable=too-many-arguments +# pylint: disable=too-many-locals +# pylint: disable=too-many-lines +# pylint: disable=unspecified-encoding + +BASE_DIR = Path.cwd() +DISORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "disordered_tracing" +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" + + +@pytest.mark.parametrize( + ( + "cropped_image", + "cropped_mask", + "pixel_to_nm_scaling", + "mask_smoothing_params", + "skeletonisation_params", + "pruning_params", + "filename", + "min_skeleton_size", + "expected_smoothed_grain", + "expected_skeleton", + "expected_pruned_skeleton", + "expected_branch_types", + ), + [ + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 1, 2, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": -1, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected 
smoothed grain + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ), + id="simple slightly curved line", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 
1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 
2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.2, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": -1, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.float32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + id="test height bias: thick curve height weighting outer, strong height bias", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 
2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.8, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": -1, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 
0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.float32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + id="test height bias: thick curve height weighting outer, weak height bias", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 
0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 2, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + 
[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 
0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + id="test pruning: thick curve with tail, no pruning", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 
1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 2, 4, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 2, 4, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 2, 4, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 1, 4, 3, 2, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [0, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 10, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + id="test pruning: thick curve with tail, prune small branch", + ), + pytest.param( + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 
0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ] + ).astype(np.float32), + # Cropped mask + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 4, 3, 3, 3, 2, 2, 2, 1, 1, 0, 0, 1, 4, 1, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 2, 4, 3, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4, 2, 0, 0], + [0, 0, 1, 4, 3, 2, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 4, 2, 0, 0], + [0, 0, 0, 1, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0], + ] + ).astype(np.bool_), + # Pixel to nm scaling + 1.0, + # Mask smoothing parameters + { + "gaussian_sigma": 1, + "dilation_iterations": 1, + "holearea_min_max": [3, 5], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": -1, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + # Filename + "test_image", + # Min skeleton size + 6, + # Expected smoothed grain + # Cropped image + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int32, + ), + # Expected skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected pruned skeleton + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.bool_, + ), + # Expected branch types + np.array( + [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 0, 0, 0, 0], + [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 3, 3, 0, 0, 0, 3, 0, 0, 0], + [0, 0, 0, 0, 2, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 0, 3, 3, 3, 3, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ], + dtype=np.int8, + ), + id="test re-add holes, 3 holes, one right size for re-adding", + ), + ], +) +def 
test_disordered_trace_grain( + cropped_image: npt.NDArray[np.float32], + cropped_mask: npt.NDArray[np.bool_], + pixel_to_nm_scaling: float, + mask_smoothing_params: dict, + skeletonisation_params: dict, + pruning_params: dict, + filename: str, + min_skeleton_size: int, + expected_smoothed_grain: npt.NDArray[np.bool_], + expected_skeleton: npt.NDArray[np.bool_], + expected_pruned_skeleton: npt.NDArray[np.bool_], + expected_branch_types: npt.NDArray[np.int32], +) -> None: + """Test the disorderedTrace() method.""" + result_dict = disordered_trace_grain( + cropped_image=cropped_image, + cropped_mask=cropped_mask, + pixel_to_nm_scaling=pixel_to_nm_scaling, + mask_smoothing_params=mask_smoothing_params, + skeletonisation_params=skeletonisation_params, + pruning_params=pruning_params, + filename=filename, + min_skeleton_size=min_skeleton_size, + ) + + result_smoothed_grain = result_dict["smoothed_grain"] + result_skeleton = result_dict["skeleton"] + result_pruned_skeleton = result_dict["pruned_skeleton"] + result_branch_types = result_dict["branch_types"] + + np.testing.assert_array_equal(result_smoothed_grain, expected_smoothed_grain) + np.testing.assert_array_equal(result_skeleton, expected_skeleton) + np.testing.assert_array_equal(result_pruned_skeleton, expected_pruned_skeleton) + np.testing.assert_array_equal(result_branch_types, expected_branch_types) + + +@pytest.mark.parametrize( + ( + "image_filename", + "mask_filename", + "pixel_to_nm_scaling", + "min_skeleton_size", + "mask_smoothing_params", + "skeletonisation_params", + "pruning_params", + "expected_disordered_crop_data_filename", + "expected_disordered_tracing_grainstats_filename", + "expected_all_images_filename", + "expected_disordered_tracing_stats_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + "example_catenanes_labelled_grain_mask_thresholded.npy", + # Pixel to nm scaling + 0.488, + # Min skeleton size + 10, + # Mask smoothing parameters + { + "gaussian_sigma": 2, + 
"dilation_iterations": 2, + "holearea_min_max": [10, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 7.0, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + "catenanes_disordered_tracing_crop_data.pkl", + "catenanes_disordered_tracing_grainstats.csv", + "catenanes_disordered_tracing_all_images.pkl", + "catenanes_disordered_tracing_stats.csv", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + "example_rep_int_labelled_grain_mask_thresholded.npy", + # Pixel to nm scaling + 0.488, + # Min skeleton size + 10, + # Mask smoothing parameters + { + "gaussian_sigma": 2, + "dilation_iterations": 2, + "holearea_min_max": [10, None], + }, + # Skeletonisation parameters + { + "method": "topostats", + "height_bias": 0.6, + }, + # Pruning parameters + { + "method": "topostats", + "max_length": 20.0, + "height_threshold": None, + "method_values": "mid", + "method_outlier": "mean_abs", + }, + "rep_int_disordered_tracing_crop_data.pkl", + "rep_int_disordered_tracing_grainstats.csv", + "rep_int_disordered_tracing_all_images.pkl", + "rep_int_disordered_tracing_stats.csv", + id="replication intermediate", + ), + ], +) +def test_trace_image_disordered( + image_filename: str, + mask_filename: str, + pixel_to_nm_scaling: float, + min_skeleton_size: int, + mask_smoothing_params: dict, + skeletonisation_params: dict, + pruning_params: dict, + expected_disordered_crop_data_filename: str, + expected_disordered_tracing_grainstats_filename: str, + expected_all_images_filename: str, + expected_disordered_tracing_stats_filename: str, +) -> None: + """Test the trace image disordered method.""" + # Load the image + image = np.load(GENERAL_RESOURCES / image_filename) + mask = np.load(GENERAL_RESOURCES / mask_filename) + + ( + result_disordered_crop_data, + result_disordered_tracing_grainstats, + result_all_images, + 
result_disordered_tracing_stats, + ) = trace_image_disordered( + image=image, + grains_mask=mask, + filename="test_image", + pixel_to_nm_scaling=pixel_to_nm_scaling, + min_skeleton_size=min_skeleton_size, + mask_smoothing_params=mask_smoothing_params, + skeletonisation_params=skeletonisation_params, + pruning_params=pruning_params, + pad_width=1, + ) + + # DEBUGGING CODE + # Turning sub-structures into variables to be able to be inspected + # variable_smoothed_grain = result_all_images["smoothed_grain"] + # variable_skeleton = result_all_images["skeleton"] + # variable_pruned_skeleton = result_all_images["pruned_skeleton"] + # variable_branch_types = result_all_images["branch_types"] + + # Update expected values - CHECK RESULTS WITH EXPERT BEFORE UPDATING + # Pickle result_disordered_crop_data + # with open(DISORDERED_TRACING_RESOURCES / expected_disordered_crop_data_filename, "wb") as f: + # pkl.dump(result_disordered_crop_data, f) + + # # Save result_disordered_tracing_grainstats as a csv + # result_disordered_tracing_grainstats.to_csv( + # DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_grainstats_filename + # ) + + # # Save result_all_images as a pickle + # with open(DISORDERED_TRACING_RESOURCES / expected_all_images_filename, "wb") as f: + # pkl.dump(result_all_images, f) + + # # Save result_disordered_tracing_stats dataframe as a csv + # result_disordered_tracing_stats.to_csv(DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_stats_filename) + + # Load expected values + with Path.open(DISORDERED_TRACING_RESOURCES / expected_disordered_crop_data_filename, "rb") as f: + expected_disordered_crop_data = pkl.load(f) + + expected_disordered_tracing_grainstats = pd.read_csv( + DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_grainstats_filename, index_col=0 + ) + + with Path.open(DISORDERED_TRACING_RESOURCES / expected_all_images_filename, "rb") as f: + expected_all_images = pkl.load(f) + + expected_disordered_tracing_stats = 
pd.read_csv( + DISORDERED_TRACING_RESOURCES / expected_disordered_tracing_stats_filename, index_col=0 + ) + + assert dict_almost_equal(result_disordered_crop_data, expected_disordered_crop_data, abs_tol=1e-11) + pd.testing.assert_frame_equal(result_disordered_tracing_grainstats, expected_disordered_tracing_grainstats) + assert dict_almost_equal(result_all_images, expected_all_images, abs_tol=1e-11) + pd.testing.assert_frame_equal(result_disordered_tracing_stats, expected_disordered_tracing_stats) diff --git a/tests/tracing/test_nodestats.py b/tests/tracing/test_nodestats.py index 10bab2f20c..17d5ae5eb4 100644 --- a/tests/tracing/test_nodestats.py +++ b/tests/tracing/test_nodestats.py @@ -7,19 +7,25 @@ import numpy as np import numpy.typing as npt +import pandas as pd import pytest from pytest_lazyfixture import lazy_fixture -from topostats.tracing.nodestats import nodeStats +# pylint: disable=import-error +# pylint: disable=no-name-in-module +from topostats.io import dict_almost_equal +from topostats.tracing.nodestats import nodeStats, nodestats_image BASE_DIR = Path.cwd() -RESOURCES = BASE_DIR / "tests" / "resources" - +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" +DISORDERED_TRACING_RESOURCES = GENERAL_RESOURCES / "tracing" / "disordered_tracing" +NODESTATS_RESOURCES = GENERAL_RESOURCES / "tracing" / "nodestats" # from topostats.tracing.nodestats import nodeStats # pylint: disable=unnecessary-pass # pylint: disable=too-many-arguments # pylint: disable=too-many-locals +# pylint: disable=too-many-lines # @pytest.mark.parametrize() @@ -164,11 +170,12 @@ def test_connect_extended_nodes_nearest( mask=np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]), smoothed_mask=np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]]), skeleton=connected_nodes.astype(bool), - px_2_nm=np.float64(1.0), + pixel_to_nm_scaling=np.float64(1.0), n_grain=0, node_joining_length=0.0, node_extend_dist=14.0, branch_pairing_length=20.0, + pair_odd_branches=True, ) nodestats.whole_skel_graph = 
nodestats.skeleton_image_to_graph(nodestats.skeleton) result = nodestats.connect_extended_nodes_nearest(connected_nodes, node_extend_dist=8.0) @@ -184,9 +191,6 @@ def test_find_branch_starts() -> None: # Create nodestats class using the cats image - will allow running the code for diagnostics def test_analyse_nodes( nodestats_catenane: nodeStats, - nodestats_catenane_node_dict: dict, - nodestats_catenane_image_dict: dict, - nodestats_catenane_all_connected_nodes: npt.NDArray[np.int32], ) -> None: """Test of analyse_nodes() method of nodeStats class.""" nodestats_catenane.analyse_nodes(max_branch_length=20) @@ -194,32 +198,36 @@ def test_analyse_nodes( node_dict_result = nodestats_catenane.node_dicts image_dict_result = nodestats_catenane.image_dict - # Nodestats dict has structure: - # "node_1": - # - error: Bool - # - px_2_nm: float - # - crossing_type: None - # - branch_starts: dict: - # - ordered coords: array Nx2 - # - heights: array Nx2 - # - distances: array Nx2 - # - fwhm2: tuple(float, list(3?), list (3?)) - - # Image dict has structure: - # - nodes - # - node_1: dict - # - node_area_skeleton: array NxN - # - node_branch_mask: array NxN - # - node_average_mask: array NxN - # - node_2 ... 
- # - grain - # - grain_image: array NxN - # - grain_mask: array NxN - # - grain_skeleton: array NxN - - np.testing.assert_equal(node_dict_result, nodestats_catenane_node_dict) - np.testing.assert_equal(image_dict_result, nodestats_catenane_image_dict) - np.testing.assert_array_equal(nodestats_catenane.all_connected_nodes, nodestats_catenane_all_connected_nodes) + # Debugging + # Save the results to overwrite expected results + # with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_node_dict.pkl").open("wb") as f: + # pickle.dump(node_dict_result, f) + + # with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_image_dict.pkl").open("wb") as f: + # pickle.dump(image_dict_result, f) + + # np.save( + # GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_all_connected_nodes.npy", + # nodestats_catenane.all_connected_nodes, + # ) + + # Load the nodestats catenane node dict from pickle + with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_node_dict.pkl").open("rb") as f: + expected_nodestats_catenane_node_dict = pickle.load(f) + + # Load the nodestats catenane image dict from pickle + with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_image_dict.pkl").open("rb") as f: + expected_nodestats_catenane_image_dict = pickle.load(f) + + # Load the nodestats catenane all connected nodes from pickle + with Path(GENERAL_RESOURCES / "nodestats_analyse_nodes_catenane_all_connected_nodes.npy").open("rb") as f: + expected_nodestats_catenane_all_connected_nodes = np.load(f) + + np.testing.assert_equal(node_dict_result, expected_nodestats_catenane_node_dict) + np.testing.assert_equal(image_dict_result, expected_nodestats_catenane_image_dict) + np.testing.assert_array_equal( + nodestats_catenane.all_connected_nodes, expected_nodestats_catenane_all_connected_nodes + ) @pytest.mark.parametrize( @@ -264,22 +272,22 @@ def test_add_branches_to_labelled_image( ) -> None: """Test of add_branches_to_labelled_image() method of nodeStats class.""" # 
Load the matched branches - with Path(RESOURCES / f"{matched_branches_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{matched_branches_filename}").open("rb") as f: matched_branches: dict[int, dict[str, npt.NDArray[np.number]]] = pickle.load(f) # Load the masked image - with Path(RESOURCES / f"{masked_image_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{masked_image_filename}").open("rb") as f: masked_image: dict[int, dict[str, npt.NDArray[np.bool_]]] = pickle.load(f) # Load the ordered branches - with Path(RESOURCES / f"{ordered_branches_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{ordered_branches_filename}").open("rb") as f: ordered_branches: list[npt.NDArray[np.int32]] = pickle.load(f) # Load the branch image - expected_branch_image: npt.NDArray[np.int32] = np.load(RESOURCES / expected_branch_image_filename) + expected_branch_image: npt.NDArray[np.int32] = np.load(GENERAL_RESOURCES / expected_branch_image_filename) # Load the average image - expected_average_image: npt.NDArray[np.float64] = np.load(RESOURCES / expected_average_image_filename) + expected_average_image: npt.NDArray[np.float64] = np.load(GENERAL_RESOURCES / expected_average_image_filename) result_branch_image, result_average_image = nodeStats.add_branches_to_labelled_image( branch_under_over_order=branch_under_over_order, @@ -296,6 +304,7 @@ def test_add_branches_to_labelled_image( np.testing.assert_equal(result_average_image, expected_average_image) +# FIXME Need a test for not pairing odd branches. Will need a test image with 3-nodes. 
@pytest.mark.parametrize( ( "p_to_nm", @@ -306,6 +315,7 @@ def test_add_branches_to_labelled_image( "image", "average_trace_advised", "node_coord", + "pair_odd_branches", "filename", "resolution_threshold", "expected_pairs", @@ -314,6 +324,7 @@ def test_add_branches_to_labelled_image( "expected_masked_image_filename", "expected_branch_under_over_order", "expected_conf", + "expected_singlet_branch_vectors", ), [ pytest.param( @@ -325,6 +336,7 @@ def test_add_branches_to_labelled_image( lazy_fixture("catenane_image"), True, (np.int32(280), np.int32(353)), + True, "catenane_test_image", np.float64(1000 / 512), np.array([(1, 3), (2, 0)]), @@ -333,6 +345,12 @@ def test_add_branches_to_labelled_image( "catenane_node_0_masked_image.pkl", np.array([0, 1]), 0.48972025484111525, + [ + np.array([-0.97044686, -0.24131493]), + np.array([0.10375883, -0.99460249]), + np.array([0.98972257, -0.14300081]), + np.array([0.46367343, 0.88600618]), + ], id="node 0", ) ], @@ -346,6 +364,7 @@ def test_analyse_node_branches( image: npt.NDArray[np.float64], average_trace_advised: bool, node_coord: tuple[np.int32, np.int32], + pair_odd_branches: np.bool_, filename: str, resolution_threshold: np.float64, expected_pairs: npt.NDArray[np.int32], @@ -354,13 +373,14 @@ def test_analyse_node_branches( expected_masked_image_filename: str, expected_branch_under_over_order: npt.NDArray[np.int32], expected_conf: float, + expected_singlet_branch_vectors: list[npt.NDArray[np.int32]], ) -> None: """Test of analyse_node_branches() method of nodeStats class.""" # Load the reduced node area - reduced_node_area = np.load(RESOURCES / f"{reduced_node_area_filename}") + reduced_node_area = np.load(GENERAL_RESOURCES / f"{reduced_node_area_filename}") # Load the reduced skeleton graph - with Path(RESOURCES / f"{reduced_skeleton_graph_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{reduced_skeleton_graph_filename}").open("rb") as f: reduced_skeleton_graph = pickle.load(f) ( @@ -370,6 +390,7 @@ def 
test_analyse_node_branches( result_masked_image, result_branch_idx_order, result_conf, + result_singlet_branch_vectors, ) = nodeStats.analyse_node_branches( p_to_nm=np.float64(p_to_nm), reduced_node_area=reduced_node_area, @@ -379,19 +400,20 @@ def test_analyse_node_branches( image=image, average_trace_advised=average_trace_advised, node_coord=node_coord, + pair_odd_branches=pair_odd_branches, filename=filename, resolution_threshold=resolution_threshold, ) # Load expected matched branches - with Path(RESOURCES / f"{expected_matched_branches_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{expected_matched_branches_filename}").open("rb") as f: expected_matched_branches = pickle.load(f) # Load expected masked image - with Path(RESOURCES / f"{expected_masked_image_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{expected_masked_image_filename}").open("rb") as f: expected_masked_image = pickle.load(f) # Load expected ordered branches - with Path(RESOURCES / f"{expected_ordered_branches_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{expected_ordered_branches_filename}").open("rb") as f: expected_ordered_branches = pickle.load(f) np.testing.assert_equal(result_pairs, expected_pairs) @@ -400,6 +422,7 @@ def test_analyse_node_branches( np.testing.assert_equal(result_masked_image, expected_masked_image) np.testing.assert_equal(result_branch_idx_order, expected_branch_under_over_order) np.testing.assert_almost_equal(result_conf, expected_conf, decimal=6) + np.testing.assert_almost_equal(result_singlet_branch_vectors, expected_singlet_branch_vectors, decimal=6) @pytest.mark.parametrize( @@ -490,19 +513,19 @@ def test_join_matching_branches_through_node( ) -> None: """Test of join_matching_branches_through_node() method of nodeStats class.""" # Load the ordered branches - with Path(RESOURCES / f"{ordered_branches_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{ordered_branches_filename}").open("rb") as f: 
ordered_branches = pickle.load(f) # Load the reduced skeleton graph - with Path(RESOURCES / f"{reduced_skeleton_graph_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{reduced_skeleton_graph_filename}").open("rb") as f: reduced_skeleton_graph = pickle.load(f) # Load expected matched branches - with Path(RESOURCES / f"{expected_matched_branches_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{expected_matched_branches_filename}").open("rb") as f: expected_matched_branches = pickle.load(f) # Load expected masked image - with Path(RESOURCES / f"{expected_masked_image_filename}").open("rb") as f: + with Path(GENERAL_RESOURCES / f"{expected_masked_image_filename}").open("rb") as f: expected_masked_image = pickle.load(f) result_matched_branches, result_masked_image = nodeStats.join_matching_branches_through_node( @@ -583,11 +606,6 @@ def test_auc() -> None: pass -def test_get_two_combinations() -> None: - """Test of get_two_combinations() method of nodeStats class.""" - pass - - def test_cross_confidence() -> None: """Test of cross_confidence() method of nodeStats class.""" pass @@ -849,3 +867,161 @@ def test_average_crossing_confs() -> None: def test_minimum_crossing_confs() -> None: """Test minimum_crossing_confs() method of nodeStats class.""" pass + + +@pytest.mark.parametrize( + ( + "image_filename", + "pixel_to_nm_scaling", + "disordered_tracing_crop_data_filename", + "node_joining_length", + "node_extend_dist", + "branch_pairing_length", + "pair_odd_branches", + "expected_nodestats_data_filename", + "expected_nodestats_grainstats_filename", + "expected_nodestats_all_images_filename", + "expected_nodestats_branch_images_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + # Pixel to nm scaling + 0.488, + "catenanes_disordered_tracing_crop_data.pkl", + # Node joining length + 7.0, + # Node extend distance + 14.0, + # Branch pairing length + 20.0, + # Pair odd branches + True, + "catenanes_nodestats_data.pkl", + 
"catenanes_nodestats_grainstats.csv", + "catenanes_nodestats_all_images.pkl", + "catenanes_nodestats_branch_images.pkl", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + # Pixel to nm scaling + 0.488, + "rep_int_disordered_tracing_crop_data.pkl", + # Node joining length + 7.0, + # Node extend distance + 14.0, + # Branch pairing length + 20.0, + # Pair odd branches + False, + "rep_int_nodestats_data_no_pair_odd_branches.pkl", + "rep_int_nodestats_grainstats_no_pair_odd_branches.csv", + "rep_int_nodestats_all_images_no_pair_odd_branches.pkl", + "rep_int_nodestats_branch_images_no_pair_odd_branches.pkl", + id="replication_intermediate, not pairing odd branches", + ), + pytest.param( + "example_rep_int.npy", + # Pixel to nm scaling + 0.488, + "rep_int_disordered_tracing_crop_data.pkl", + # Node joining length + 7.0, + # Node extend distance + 14.0, + # Branch pairing length + 20.0, + # Pair odd branches + True, + "rep_int_nodestats_data_pair_odd_branches.pkl", + "rep_int_nodestats_grainstats_pair_odd_branches.csv", + "rep_int_nodestats_all_images_pair_odd_branches.pkl", + "rep_int_nodestats_branch_images_pair_odd_branches.pkl", + id="replication_intermediate, pairing odd branches", + ), + ], +) +def test_nodestats_image( + image_filename: str, + pixel_to_nm_scaling: float, + disordered_tracing_crop_data_filename: str, + node_joining_length: float, + node_extend_dist: float, + branch_pairing_length: float, + pair_odd_branches: bool, + expected_nodestats_data_filename: str, + expected_nodestats_grainstats_filename: str, + expected_nodestats_all_images_filename: str, + expected_nodestats_branch_images_filename: str, +) -> None: + """Test of nodestats_image() method of nodeStats class.""" + # Load the image + image = np.load(GENERAL_RESOURCES / image_filename) + # load disordered_tracing_crop_data from pickle + with Path(DISORDERED_TRACING_RESOURCES / disordered_tracing_crop_data_filename).open("rb") as f: + disordered_tracing_crop_data = pickle.load(f) + + 
( + result_nodestats_data, + result_nodestats_grainstats, + result_nodestats_all_images, + result_nodestats_branch_images, + ) = nodestats_image( + image=image, + disordered_tracing_direction_data=disordered_tracing_crop_data, + filename="test_image", + pixel_to_nm_scaling=pixel_to_nm_scaling, + node_joining_length=node_joining_length, + node_extend_dist=node_extend_dist, + branch_pairing_length=branch_pairing_length, + pair_odd_branches=pair_odd_branches, + pad_width=1, + ) + + # # DEBUGGING (For viewing images) + # convolved_skeletons = result_all_images["convolved_skeletons"] + # node_centres = result_all_images["node_centres"] + # connected_nodes = result_all_images["connected_nodes"] + + # Save the results + + # Save the result_nodestats_data + with Path(NODESTATS_RESOURCES / expected_nodestats_data_filename).open("wb") as f: + pickle.dump(result_nodestats_data, f) + + # Save the result_stats_df as a csv + # result_nodestats_grainstats.to_csv(NODESTATS_RESOURCES / expected_nodestats_grainstats_filename) + + # # Save the result_all_images + # with Path(NODESTATS_RESOURCES / expected_nodestats_all_images_filename).open("wb") as f: + # pickle.dump(result_nodestats_all_images, f) + + # # Save the result_nodestats_branch_images + # with Path(NODESTATS_RESOURCES / expected_nodestats_branch_images_filename).open("wb") as f: + # pickle.dump(result_nodestats_branch_images, f) + + # Load expected data + + # Load the expected nodestats data + with Path(NODESTATS_RESOURCES / expected_nodestats_data_filename).open("rb") as f: + expected_nodestats_data = pickle.load(f) + + # Load the expected grainstats additions + expected_nodestats_grainstats = pd.read_csv( + NODESTATS_RESOURCES / expected_nodestats_grainstats_filename, index_col=0 + ) + + # Load the expected all images + with Path(NODESTATS_RESOURCES / expected_nodestats_all_images_filename).open("rb") as f: + expected_all_images = pickle.load(f) + + # Load the expected nodestats branch images + with 
Path(NODESTATS_RESOURCES / expected_nodestats_branch_images_filename).open("rb") as f: + expected_nodestats_branch_images = pickle.load(f) + + assert dict_almost_equal(result_nodestats_data, expected_nodestats_data, abs_tol=1e-3) + pd.testing.assert_frame_equal(result_nodestats_grainstats, expected_nodestats_grainstats) + assert dict_almost_equal(result_nodestats_all_images, expected_all_images) + assert dict_almost_equal(result_nodestats_branch_images, expected_nodestats_branch_images) diff --git a/tests/tracing/test_ordered_tracing.py b/tests/tracing/test_ordered_tracing.py new file mode 100644 index 0000000000..a2ddd1f3ee --- /dev/null +++ b/tests/tracing/test_ordered_tracing.py @@ -0,0 +1,146 @@ +# Disable ruff 301 - pickle loading is unsafe, but we don't care for tests +# ruff: noqa: S301 +"""Test the ordered tracing module.""" + +import pickle +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest + +from topostats.tracing.ordered_tracing import ordered_tracing_image + +BASE_DIR = Path.cwd() +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" +ORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "ordered_tracing" +NODESTATS_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "nodestats" +DISORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "disordered_tracing" + +# pylint: disable=unspecified-encoding +# pylint: disable=too-many-locals +# pylint: disable=too-many-arguments + + +@pytest.mark.parametrize( + ( + "image_filename", + "disordered_tracing_direction_data_filename", + "nodestats_data_filename", + "nodestats_branch_images_filename", + "filename", + "expected_ordered_tracing_data_filename", + "expected_ordered_tracing_grainstats_filename", + "expected_molstats_filename", + "expected_ordered_tracing_full_images_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + "catenanes_disordered_tracing_crop_data.pkl", + "catenanes_nodestats_data.pkl", + 
"catenanes_nodestats_branch_images.pkl", + "catenane", # filename + "catenanes_ordered_tracing_data.pkl", + "catenanes_ordered_tracing_grainstats.csv", + "catenanes_ordered_tracing_molstats.csv", + "catenanes_ordered_tracing_full_images.pkl", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + "rep_int_disordered_tracing_crop_data.pkl", + "rep_int_nodestats_data_no_pair_odd_branches.pkl", + "rep_int_nodestats_branch_images_no_pair_odd_branches.pkl", + "replication_intermediate", # filename + "rep_int_ordered_tracing_data.pkl", + "rep_int_ordered_tracing_grainstats.csv", + "rep_int_ordered_tracing_molstats.csv", + "rep_int_ordered_tracing_full_images.pkl", + id="replication_intermediate", + ), + ], +) +def test_ordered_tracing_image( + image_filename: str, + disordered_tracing_direction_data_filename: str, + nodestats_data_filename: str, + nodestats_branch_images_filename: str, + filename: str, + expected_ordered_tracing_data_filename: str, + expected_ordered_tracing_grainstats_filename: str, + expected_molstats_filename: str, + expected_ordered_tracing_full_images_filename: str, +) -> None: + """Test the ordered tracing image method of ordered tracing.""" + # disordered_tracing_direction_data is the disordered tracing data + # for a particular threshold direction. 
+ + # nodestats_direction_data contains both nodestats_data and nodestats_branch_images + + # Load the required data + image = np.load(GENERAL_RESOURCES / image_filename) + + with Path.open(DISORDERED_TRACING_RESOURCES / disordered_tracing_direction_data_filename, "rb") as f: + disordered_tracing_direction_data = pickle.load(f) + + with Path.open(NODESTATS_RESOURCES / nodestats_data_filename, "rb") as f: + nodestats_data = pickle.load(f) + + with Path.open(NODESTATS_RESOURCES / nodestats_branch_images_filename, "rb") as f: + nodestats_branch_images = pickle.load(f) + + nodestats_whole_data = {"stats": nodestats_data, "images": nodestats_branch_images} + + ( + result_ordered_tracing_data, + result_ordered_tracing_grainstats, + result_molstats_df, + result_ordered_tracing_full_images, + ) = ordered_tracing_image( + image=image, + disordered_tracing_direction_data=disordered_tracing_direction_data, + nodestats_direction_data=nodestats_whole_data, + filename=filename, + ordering_method="nodestats", + pad_width=1, + ) + + # # Debugging - grab variables to show images + # variable_ordered_traces = result_ordered_tracing_full_images["ordered_traces"] + # variable_all_molecules = result_ordered_tracing_full_images["all_molecules"] + # variable_over_under = result_ordered_tracing_full_images["over_under"] + # variable_trace_segments = result_ordered_tracing_full_images["trace_segments"] + + # # Save result ordered tracing data as pickle + # with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_data_filename, "wb") as f: + # pickle.dump(result_ordered_tracing_data, f) + + # # Save result grainstats additions as csv + # result_ordered_tracing_grainstats.to_csv(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_grainstats_filename) + + # # Save the molstats dataframe as csv + # result_molstats_df.to_csv(ORDERED_TRACING_RESOURCES / expected_molstats_filename) + + # # Save result ordered tracing full images as pickle + # with Path.open(ORDERED_TRACING_RESOURCES / 
expected_ordered_tracing_full_images_filename, "wb") as f: + # pickle.dump(result_ordered_tracing_full_images, f) + + # Load the expected results + with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_data_filename, "rb") as f: + expected_ordered_tracing_data = pickle.load(f) + + expected_ordered_tracing_grainstats = pd.read_csv( + ORDERED_TRACING_RESOURCES / expected_ordered_tracing_grainstats_filename, index_col=0 + ) + + expected_molstats_df = pd.read_csv(ORDERED_TRACING_RESOURCES / expected_molstats_filename, index_col=0) + + with Path.open(ORDERED_TRACING_RESOURCES / expected_ordered_tracing_full_images_filename, "rb") as f: + expected_ordered_tracing_full_images = pickle.load(f) + + # Check the results + np.testing.assert_equal(result_ordered_tracing_data, expected_ordered_tracing_data) + pd.testing.assert_frame_equal(result_ordered_tracing_grainstats, expected_ordered_tracing_grainstats) + pd.testing.assert_frame_equal(result_molstats_df, expected_molstats_df) + np.testing.assert_equal(result_ordered_tracing_full_images, expected_ordered_tracing_full_images) diff --git a/tests/tracing/test_splining.py b/tests/tracing/test_splining.py index 9ebcd320f6..7723d703fe 100644 --- a/tests/tracing/test_splining.py +++ b/tests/tracing/test_splining.py @@ -1,19 +1,208 @@ +# Disable ruff 301 - pickle loading is unsafe, but we don't care for tests +# ruff: noqa: S301 """Test the splining module.""" +import pickle +from pathlib import Path + +import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pandas as pd import pytest -from topostats.tracing.splining import windowTrace +from topostats.tracing.splining import splining_image, windowTrace -# pylint: disable=too-many-arguments +BASE_DIR = Path.cwd() +GENERAL_RESOURCES = BASE_DIR / "tests" / "resources" +SPLINING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "splining" +ORDERED_TRACING_RESOURCES = BASE_DIR / "tests" / "resources" / "tracing" / "ordered_tracing" + 
+# pylint: disable=unspecified-encoding # pylint: disable=too-many-locals +# pylint: disable=too-many-arguments PIXEL_TRACE = np.array( [[0, 0], [0, 1], [0, 2], [0, 3], [1, 3], [2, 3], [3, 3], [3, 2], [3, 1], [3, 0], [2, 0], [1, 0]] ).astype(np.int32) +def plot_spline_debugging( + image: npt.NDArray[np.float32], + result_all_splines_data: dict, + pixel_to_nm_scaling: float, +) -> None: + """ + Plot splines of an image overlaid on the image. + + Used for debugging changes to the splining code & visually ensuring the splines are correct. + + Parameters + ---------- + image : npt.NDArray[np.float32] + Image to plot the splines on. + result_all_splines_data : dict + Dictionary containing the spline coordinates for each molecule. + pixel_to_nm_scaling : float + Pixel to nm scaling factor. + """ + _, ax = plt.subplots(figsize=(10, 10)) + ax.imshow(image, cmap="gray") + # Array of lots of matplotlib colours + lots_of_colours = [ + "blue", + "green", + "red", + "cyan", + "magenta", + "yellow", + "black", + "white", + "orange", + "purple", + ] + for grain_key_index, grain_key in enumerate(result_all_splines_data.keys()): + print(f"Grain key: {grain_key}") + for mol_key_index, mol_key in enumerate(result_all_splines_data[grain_key].keys()): + spline_coords: npt.NDArray[np.float32] = result_all_splines_data[grain_key][mol_key]["spline_coords"] + bbox = result_all_splines_data[grain_key][mol_key]["bbox"] + bbox_min_col = bbox[0] + bbox_min_row = bbox[1] + previous_point = spline_coords[0] + colour = lots_of_colours[mol_key_index + grain_key_index * 3 % len(lots_of_colours)] + for point in spline_coords[1:]: + ax.plot( + [ + previous_point[1] / pixel_to_nm_scaling + bbox_min_row, + point[1] / pixel_to_nm_scaling + bbox_min_row, + ], + [ + previous_point[0] / pixel_to_nm_scaling + bbox_min_col, + point[0] / pixel_to_nm_scaling + bbox_min_col, + ], + color=colour, + linewidth=2, + ) + previous_point = point + plt.show() + + +@pytest.mark.parametrize( + ( + "image_filename", + 
"ordered_tracing_direction_data_filename", + "pixel_to_nm_scaling", + "splining_method", + "rolling_window_size", + "spline_step_size", + "spline_linear_smoothing", + "spline_circular_smoothing", + "spline_degree", + "filename", + "expected_all_splines_data_filename", + "expected_splining_grainstats_filename", + "expected_molstats_filename", + ), + [ + pytest.param( + "example_catenanes.npy", + "catenanes_ordered_tracing_data.pkl", + 1.0, # pixel_to_nm_scaling + # Splining parameters + "rolling_window", # splining_method + 20e-9, # rolling_window_size + 7.0e-9, # spline_step_size + 5.0, # spline_linear_smoothing + 5.0, # spline_circular_smoothing + 3, # spline_degree + "catenane", # filename + "catenanes_splining_data.pkl", + "catenanes_splining_grainstats.csv", + "catenanes_splining_molstats.csv", + id="catenane", + ), + pytest.param( + "example_rep_int.npy", + "rep_int_ordered_tracing_data.pkl", + 1.0, # pixel_to_nm_scaling + # Splining parameters + "rolling_window", # splining_method + 20e-9, # rolling_window_size + 7.0e-9, # spline_step_size + 5.0, # spline_linear_smoothing + 5.0, # spline_circular_smoothing + 3, # spline_degree + "replication_intermediate", # filename + "rep_int_splining_data.pkl", + "rep_int_splining_grainstats.csv", + "rep_int_splining_molstats.csv", + id="replication_intermediate", + ), + ], +) +def test_splining_image( + image_filename: str, + ordered_tracing_direction_data_filename: str, + pixel_to_nm_scaling: float, + splining_method: str, + rolling_window_size: float, + spline_step_size: float, + spline_linear_smoothing: float, + spline_circular_smoothing: float, + spline_degree: int, + filename: str, + expected_all_splines_data_filename: str, + expected_splining_grainstats_filename: str, + expected_molstats_filename: str, +) -> None: + """Test the splining_image function of the splining module.""" + # Load the data + image = np.load(GENERAL_RESOURCES / image_filename) + + # Load the ordered tracing direction data + with 
Path.open(ORDERED_TRACING_RESOURCES / ordered_tracing_direction_data_filename, "rb") as file: + ordered_tracing_direction_data = pickle.load(file) + + result_all_splines_data, result_splining_grainstats, result_molstats_df = splining_image( + image=image, + ordered_tracing_direction_data=ordered_tracing_direction_data, + pixel_to_nm_scaling=pixel_to_nm_scaling, + filename=filename, + method=splining_method, + rolling_window_size=rolling_window_size, + spline_step_size=spline_step_size, + spline_linear_smoothing=spline_linear_smoothing, + spline_circular_smoothing=spline_circular_smoothing, + spline_degree=spline_degree, + ) + + # When updating the test, you will want to verify that the splines are correct. Use + # plot_spline_debugging to plot the splines on the image. + + # # Save the results to update the test data + # # Save result splining data as pickle + # with Path.open(SPLINING_RESOURCES / expected_all_splines_data_filename, "wb") as file: + # pickle.dump(result_all_splines_data, file) + + # # Save result grainstats additions as csv + # result_splining_grainstats.to_csv(SPLINING_RESOURCES / expected_splining_grainstats_filename) + + # # Save result molstats as csv + # result_molstats_df.to_csv(SPLINING_RESOURCES / expected_molstats_filename) + + # Load the expected results + with Path.open(SPLINING_RESOURCES / expected_all_splines_data_filename, "rb") as file: + expected_all_splines_data = pickle.load(file) + + expected_splining_grainstats = pd.read_csv(SPLINING_RESOURCES / expected_splining_grainstats_filename, index_col=0) + expected_molstats_df = pd.read_csv(SPLINING_RESOURCES / expected_molstats_filename, index_col=0) + + # Check the results + np.testing.assert_equal(result_all_splines_data, expected_all_splines_data) + pd.testing.assert_frame_equal(result_splining_grainstats, expected_splining_grainstats) + pd.testing.assert_frame_equal(result_molstats_df, expected_molstats_df) + + @pytest.mark.parametrize( ("pixel_trace", "rolling_window_size", 
"pixel_to_nm_scaling", "expected_pooled_trace"), [ diff --git a/topostats/default_config.yaml b/topostats/default_config.yaml index e6ce12d07c..95832b0227 100644 --- a/topostats/default_config.yaml +++ b/topostats/default_config.yaml @@ -65,7 +65,7 @@ disordered_tracing: height_bias: 0.6 # Percentage of lowest pixels to remove each skeletonisation iteration. 1 equates to zhang. pruning_params: method: topostats # Method to clean branches of the skeleton. Options : topostats - max_length: -1 # Maximum length in nm to remove a branch containing an endpoint. '-1' is 15% of total trace length (in pixels). + max_length: 10.0 # Maximum length in nm to remove a branch containing an endpoint. height_threshold: # The height to remove branches below. method_values: mid # The method to obtain a branch's height for pruning. Options : min | median | mid. method_outlier: mean_abs # The method to prune branches based on height. Options : abs | mean_abs | iqr. diff --git a/topostats/io.py b/topostats/io.py index 816cc1be70..5d252a5330 100644 --- a/topostats/io.py +++ b/topostats/io.py @@ -34,6 +34,54 @@ # pylint: disable=too-many-lines +# Sylvia: Ruff says too complex but I think breaking this out would be more complex. +def dict_almost_equal(dict1: dict, dict2: dict, abs_tol: float = 1e-9): # noqa: C901 + """ + Recursively check if two dictionaries are almost equal with a given absolute tolerance. + + Parameters + ---------- + dict1 : dict + First dictionary to compare. + dict2 : dict + Second dictionary to compare. + abs_tol : float + Absolute tolerance to check for equality. + + Returns + ------- + bool + True if the dictionaries are almost equal, False otherwise. 
+ """ + if dict1.keys() != dict2.keys(): + return False + + LOGGER.info("Comparing dictionaries") + + for key in dict1: + LOGGER.info(f"Comparing key {key}") + if isinstance(dict1[key], dict) and isinstance(dict2[key], dict): + if not dict_almost_equal(dict1[key], dict2[key], abs_tol=abs_tol): + return False + elif isinstance(dict1[key], np.ndarray) and isinstance(dict2[key], np.ndarray): + if not np.allclose(dict1[key], dict2[key], atol=abs_tol): + LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") + return False + elif isinstance(dict1[key], float) and isinstance(dict2[key], float): + # Skip if both values are NaN + if not (np.isnan(dict1[key]) and np.isnan(dict2[key])): + # Check if both values are close + if not np.isclose(dict1[key], dict2[key], atol=abs_tol): + LOGGER.info(f"Key {key} type: {type(dict1[key])} not equal: {dict1[key]} != {dict2[key]}") + return False + + elif dict1[key] != dict2[key]: + LOGGER.info(f"Key {key} not equal: {dict1[key]} != {dict2[key]}") + return False + + return True + + def read_yaml(filename: str | Path) -> dict: """ Read a YAML file. 
diff --git a/topostats/plotting_dictionary.yaml b/topostats/plotting_dictionary.yaml index b258de99e0..67ae1f78c1 100644 --- a/topostats/plotting_dictionary.yaml +++ b/topostats/plotting_dictionary.yaml @@ -125,30 +125,35 @@ mask_grains: filename: "17-mask_grains" title: "Mask for Grains" image_type: "binary" + mask_cmap: "binary" savefig_dpi: 100 core_set: false labelled_regions_01: filename: "18-labelled_regions" title: "Labelled Regions" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false tidied_border: filename: "19-tidy_borders" title: "Tidied Borders" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false removed_noise: filename: "20-noise_removed" title: "Noise removed" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false removed_small_objects: filename: "21-small_objects_removed" title: "Small Objects Removed" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false mask_overlay: @@ -160,12 +165,14 @@ labelled_regions_02: filename: "22-labelled_regions" title: "Labelled Regions" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false coloured_regions: filename: "23-coloured_regions" title: "Coloured Regions" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false bounding_boxes: @@ -178,6 +185,7 @@ coloured_boxes: filename: "25-labelled_image_bboxes" title: "Labelled Image with Bounding Boxes" image_type: "binary" + mask_cmap: "rainbow" savefig_dpi: 100 core_set: false grain_image: @@ -218,23 +226,30 @@ pruned_skeleton: mask_cmap: "blue" core_set: false savefig_dpi: 600 +branch_indexes: + filename: "23-segment_indexes" + title: "Skeleton Segment Indexes" + image_type: "non-binary" + mask_cmap: "viridis" + core_set: false + savefig_dpi: 600 branch_types: - filename: "23-branch_types" - title: "Skeleton Branch Types" + filename: "24-segment_types" + title: "Skeleton Segment Types" image_type: "non-binary" mask_cmap: "viridis" 
core_set: false savefig_dpi: 600 # Nodestats troubleshooting images convolved_skeletons: - filename: "24-convolved_skeleton" + filename: "25-convolved_skeleton" title: "Skeletons and Nodes" image_type: "non-binary" mask_cmap: "blue_purple_green" core_set: false savefig_dpi: 600 node_centres: - filename: "25-node_centres" + filename: "26-node_centres" title: "Skeletons and Highlighted Nodes" image_type: "non-binary" mask_cmap: "blue_purple_green" @@ -276,21 +291,21 @@ ordered_traces: core_set: false savefig_dpi: 600 trace_segments: - filename: "26-trace_segments" + filename: "27-trace_segments" title: "Trace Segments" image_type: "non-binary" mask_cmap: "gist_rainbow" savefig_dpi: 600 core_set: false over_under: - filename: "27-molecule_crossings" + filename: "28-molecule_crossings" title: "Visualised Molecule Crossings" image_type: "non-binary" mask_cmap: "blue_purple_green" core_set: false savefig_dpi: 600 all_molecules: - filename: "28-all_molecules" + filename: "29-all_molecules" title: "Individual Molecules" image_type: "non-binary" mask_cmap: "blue_purple_green" @@ -298,7 +313,7 @@ all_molecules: core_set: false # Splining fitted_trace: - filename: "27-fitted-traces" + filename: "30-fitted-traces" title: "Fitted Trace" image_type: "non-binary" mask_cmap: "blue_purple_green" diff --git a/topostats/plottingfuncs.py b/topostats/plottingfuncs.py index c3da4f8517..807ba59db6 100644 --- a/topostats/plottingfuncs.py +++ b/topostats/plottingfuncs.py @@ -305,7 +305,7 @@ def plot_and_save(self): # Only plot if image_set is "all" (i.e. 
user wants all images) or an image is in the core_set if self.image_set == "all" or self.core_set: fig, ax = self.save_figure() - LOGGER.info( + LOGGER.debug( f"[{self.filename}] : Image saved to : {str(self.output_dir / self.filename)}.{self.savefig_format}" f" | DPI: {self.savefig_dpi}" ) diff --git a/topostats/processing.py b/topostats/processing.py index 8458d623be..e4300005fe 100644 --- a/topostats/processing.py +++ b/topostats/processing.py @@ -33,6 +33,7 @@ # pylint: disable=too-many-statements # pylint: disable=too-many-nested-blocks # pylint: disable=unnecessary-dict-index-lookup +# pylint: disable=too-many-lines LOGGER = setup_logger(LOGGER_NAME) @@ -199,7 +200,9 @@ def run_grains( # noqa: C901 for plot_name, array in image_arrays.items(): LOGGER.info(f"[{filename}] : Plotting {plot_name} image") plotting_config["plot_dict"][plot_name]["output_dir"] = grain_out_path_direction - Images(array, **plotting_config["plot_dict"][plot_name]).plot_and_save() + Images( + data=np.zeros_like(array), masked_array=array, **plotting_config["plot_dict"][plot_name] + ).plot_and_save() # Make a plot of coloured regions with bounding boxes plotting_config["plot_dict"]["bounding_boxes"]["output_dir"] = grain_out_path_direction Images( @@ -209,7 +212,8 @@ def run_grains( # noqa: C901 ).plot_and_save() plotting_config["plot_dict"]["coloured_boxes"]["output_dir"] = grain_out_path_direction Images( - grains.directions[direction]["labelled_regions_02"], + data=np.zeros_like(grains.directions[direction]["labelled_regions_02"]), + masked_array=grains.directions[direction]["labelled_regions_02"], **plotting_config["plot_dict"]["coloured_boxes"], region_properties=grains.region_properties[direction], ).plot_and_save() @@ -219,7 +223,7 @@ def run_grains( # noqa: C901 Images( image, filename=f"{filename}_{direction}_masked", - masked_array=grains.directions[direction]["removed_small_objects"], + masked_array=grains.directions[direction]["removed_small_objects"].astype(bool), 
**plotting_config["plot_dict"][plot_name], ).plot_and_save() @@ -365,7 +369,7 @@ def run_disordered_trace( tracing_out_path: Path, disordered_tracing_config: dict, plotting_config: dict, - results_df: pd.DataFrame = None, + grainstats_df: pd.DataFrame = None, ) -> dict: """ Skeletonise and prune grains, adding results to statistics data frames and optionally plot results. @@ -391,8 +395,8 @@ def run_disordered_trace( Dictionary configuration for obtaining a disordered trace representation of the grains. plotting_config : dict Dictionary configuration for plotting images. - results_df : pd.DataFrame, optional - The grainstats dataframe to bee added to. by default None. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None. Returns ------- @@ -403,7 +407,7 @@ def run_disordered_trace( disordered_tracing_config.pop("run") LOGGER.info(f"[{filename}] : *** Disordered Tracing ***") disordered_traces = defaultdict() - grainstats_additions_image = pd.DataFrame() + disordered_trace_grainstats = pd.DataFrame() disordered_tracing_stats_image = pd.DataFrame() try: # run image using directional grain masks @@ -418,7 +422,7 @@ def run_disordered_trace( # if grains are found ( disordered_traces_cropped_data, - grainstats_additions_df, + _disordered_trace_grainstats, disordered_tracing_images, disordered_tracing_stats, ) = trace_image_disordered( @@ -429,8 +433,8 @@ def run_disordered_trace( **disordered_tracing_config, ) # save per image new grainstats stats - grainstats_additions_df["threshold"] = direction - grainstats_additions_image = pd.concat([grainstats_additions_image, grainstats_additions_df]) + _disordered_trace_grainstats["threshold"] = direction + disordered_trace_grainstats = pd.concat([disordered_trace_grainstats, _disordered_trace_grainstats]) disordered_tracing_stats["threshold"] = direction disordered_tracing_stats["basename"] = basename.parent @@ -456,20 +460,23 @@ def run_disordered_trace( # merge grainstats 
data with other dataframe resultant_grainstats = ( - pd.merge(results_df, grainstats_additions_image, on=["image", "threshold", "grain_number"]) - if results_df is not None - else grainstats_additions_image + pd.merge(grainstats_df, disordered_trace_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else disordered_trace_grainstats ) return disordered_traces, resultant_grainstats, disordered_tracing_stats_image - except Exception: - LOGGER.info("Disordered tracing failed - skipping.") - return {}, results_df, None + except Exception as e: + LOGGER.info( + f"[{filename}] : Disordered tracing failed - skipping. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + return {}, grainstats_df, None else: LOGGER.info(f"[{filename}] Calculation of Disordered Tracing disabled, returning empty dictionary.") - return {}, results_df, None + return {}, grainstats_df, None def run_nodestats( # noqa: C901 @@ -481,7 +488,7 @@ def run_nodestats( # noqa: C901 tracing_out_path: Path, nodestats_config: dict, plotting_config: dict, - results_df: pd.DataFrame = None, + grainstats_df: pd.DataFrame = None, ) -> tuple[dict, pd.DataFrame]: """ Analyse crossing points in grains adding results to statistics data frames and optionally plot results. @@ -504,8 +511,8 @@ def run_nodestats( # noqa: C901 Dictionary configuration for analysing the crossing points. plotting_config : dict Dictionary configuration for plotting images. - results_df : pd.DataFrame, optional - The grainstats dataframe to bee added to. by default None. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None.
Returns ------- @@ -516,13 +523,13 @@ def run_nodestats( # noqa: C901 nodestats_config.pop("run") LOGGER.info(f"[{filename}] : *** Nodestats ***") nodestats_whole_data = defaultdict() - grainstats_additions_image = pd.DataFrame() + nodestats_grainstats = pd.DataFrame() try: # run image using directional grain masks for direction, disordered_tracing_direction_data in disordered_tracing_data.items(): ( nodestats_data, - grainstats_additions_df, + _nodestats_grainstats, nodestats_full_images, nodestats_branch_images, ) = nodestats_image( @@ -534,8 +541,8 @@ def run_nodestats( # noqa: C901 ) # save per image new grainstats stats - grainstats_additions_df["threshold"] = direction - grainstats_additions_image = pd.concat([grainstats_additions_image, grainstats_additions_df]) + _nodestats_grainstats["threshold"] = direction + nodestats_grainstats = pd.concat([nodestats_grainstats, _nodestats_grainstats]) # append direction results to dict nodestats_whole_data[direction] = {"stats": nodestats_data, "images": nodestats_branch_images} @@ -591,33 +598,37 @@ def run_nodestats( # noqa: C901 # merge grainstats data with other dataframe resultant_grainstats = ( - pd.merge(results_df, grainstats_additions_image, on=["image", "threshold", "grain_number"]) - if results_df is not None - else grainstats_additions_image + pd.merge(grainstats_df, nodestats_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else nodestats_grainstats ) # merge all image dictionaries return nodestats_whole_data, resultant_grainstats except Exception as e: - LOGGER.info(f"NodeStats failed with {e} - skipping.") - return nodestats_whole_data, grainstats_additions_image + LOGGER.info( + f"[{filename}] : NodeStats failed - skipping. Consider raising an issue on GitHub. 
Error: ", exc_info=e + ) + return nodestats_whole_data, nodestats_grainstats else: LOGGER.info(f"[{filename}] : Calculation of nodestats disabled, returning empty dataframe.") - return None, results_df + return None, grainstats_df +# need to add in the molstats here def run_ordered_tracing( image: npt.NDArray, disordered_tracing_data: dict, nodestats_data: dict, filename: str, + basename: Path, core_out_path: Path, tracing_out_path: Path, ordered_tracing_config: dict, plotting_config: dict, - results_df: pd.DataFrame = None, + grainstats_df: pd.DataFrame = None, ) -> tuple: """ Order coordinates of traces, adding results to statistics data frames and optionally plot results. @@ -632,6 +643,8 @@ def run_ordered_tracing( Dictionary of images and statistics from the NodeStats analysis. Result from "run_nodestats". filename : str Name of the image. + basename : Path + The path of the files' parent directory. core_out_path : Path Path to save the core ordered tracing image to. tracing_out_path : Path @@ -640,8 +653,8 @@ def run_ordered_tracing( Dictionary configuration for obtaining an ordered trace representation of the skeletons. plotting_config : dict Dictionary configuration for plotting images. - results_df : pd.DataFrame, optional - The grainstats dataframe to bee added to. by default None. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None. 
Returns ------- @@ -652,7 +665,8 @@ def run_ordered_tracing( ordered_tracing_config.pop("run") LOGGER.info(f"[{filename}] : *** Ordered Tracing ***") ordered_tracing_image_data = defaultdict() - grainstats_additions_image = pd.DataFrame() + ordered_tracing_molstats = pd.DataFrame() + ordered_tracing_grainstats = pd.DataFrame() try: # run image using directional grain masks @@ -667,7 +681,8 @@ def run_ordered_tracing( # if grains are found ( ordered_tracing_data, - grainstats_additions_df, + _ordered_tracing_grainstats, + _ordered_tracing_molstats, ordered_tracing_full_images, ) = ordered_tracing_image( image=image, @@ -678,13 +693,16 @@ def run_ordered_tracing( ) # save per image new grainstats stats - grainstats_additions_df["threshold"] = direction - grainstats_additions_image = pd.concat([grainstats_additions_image, grainstats_additions_df]) + _ordered_tracing_grainstats["threshold"] = direction + ordered_tracing_grainstats = pd.concat([ordered_tracing_grainstats, _ordered_tracing_grainstats]) + _ordered_tracing_molstats["threshold"] = direction + ordered_tracing_molstats = pd.concat([ordered_tracing_molstats, _ordered_tracing_molstats]) # append direction results to dict ordered_tracing_image_data[direction] = ordered_tracing_data # save whole image plots + plotting_config["plot_dict"]["ordered_traces"]["core_set"] = True # fudge around core having own cmap Images( filename=f"{filename}_{direction}_ordered_traces", data=image, @@ -705,19 +723,24 @@ def run_ordered_tracing( # merge grainstats data with other dataframe resultant_grainstats = ( - pd.merge(results_df, grainstats_additions_image, on=["image", "threshold", "grain_number"]) - if results_df is not None - else grainstats_additions_image + pd.merge(grainstats_df, ordered_tracing_grainstats, on=["image", "threshold", "grain_number"]) + if grainstats_df is not None + else ordered_tracing_grainstats ) + ordered_tracing_molstats["basename"] = basename.parent + # merge all image dictionaries - return 
ordered_tracing_image_data, resultant_grainstats + return ordered_tracing_image_data, resultant_grainstats, ordered_tracing_molstats except Exception as e: - LOGGER.info(f"Ordered Tracing failed with {e} - skipping.") - return ordered_tracing_image_data, results_df + LOGGER.info( + f"[{filename}] : Ordered Tracing failed - skipping. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) + return ordered_tracing_image_data, grainstats_df, None - return None, results_df + return None, grainstats_df, None def run_splining( @@ -725,18 +748,18 @@ def run_splining( ordered_tracing_data: dict, pixel_to_nm_scaling: float, filename: str, - basename: Path, core_out_path: Path, splining_config: dict, plotting_config: dict, - results_df: pd.DataFrame = None, + grainstats_df: pd.DataFrame = None, + molstats_df: pd.DataFrame = None, ) -> tuple: """ Smooth the ordered trace coordinates, adding results to statistics data frames and optionally plot results. Parameters ---------- - image : npt.ndarray + image : npt.NDArray Image containing the DNA to pass to the tracing function. ordered_tracing_data : dict Dictionary of ordered coordinates. Result from "run_ordered_tracing". @@ -744,16 +767,16 @@ def run_splining( Scaling factor for converting pixel length scales to nanometers, i.e. the number of pixels per nanometres (nm). filename : str Name of the image. - basename : Path - Path to directory containing the image. core_out_path : Path Path to save the core ordered tracing image to. splining_config : dict Dictionary configuration for obtaining an ordered trace representation of the skeletons. plotting_config : dict Dictionary configuration for plotting images. - results_df : pd.DataFrame, optional - The grainstats dataframe to bee added to. by default None. + grainstats_df : pd.DataFrame, optional + The grain statistics dataframe to be added to. by default None. + molstats_df : pd.DataFrame, optional + The molecule statistics dataframe to be added to. by default None. 
Returns ------- @@ -764,8 +787,8 @@ def run_splining( splining_config.pop("run") LOGGER.info(f"[{filename}] : *** Splining ***") splined_image_data = defaultdict() - splining_stats = pd.DataFrame() - image_molstats_df = pd.DataFrame() + splining_grainstats = pd.DataFrame() + splining_molstats = pd.DataFrame() try: # run image using directional grain masks @@ -774,15 +797,15 @@ def run_splining( LOGGER.warning( f"[{filename}] : No grains exist for the {direction} direction. Skipping disordered_tracing for {direction}." ) - splining_stats = create_empty_dataframe() - image_molstats_df = create_empty_dataframe(columns=["image", "basename", "threshold"]) + splining_grainstats = create_empty_dataframe() + splining_molstats = create_empty_dataframe(columns=["image", "basename", "threshold"]) raise ValueError(f"No grains exist for the {direction} direction") # if grains are found ( splined_data, - _splining_stats, - molstats_df, + _splining_grainstats, + _splining_molstats, ) = splining_image( image=image, ordered_tracing_direction_data=ordered_tracing_direction_data, @@ -792,10 +815,10 @@ def run_splining( ) # save per image new grainstats stats - _splining_stats["threshold"] = direction - splining_stats = pd.concat([splining_stats, _splining_stats]) - molstats_df["threshold"] = direction - image_molstats_df = pd.concat([image_molstats_df, molstats_df]) + _splining_grainstats["threshold"] = direction + splining_grainstats = pd.concat([splining_grainstats, _splining_grainstats]) + _splining_molstats["threshold"] = direction + splining_molstats = pd.concat([splining_molstats, _splining_molstats]) # append direction results to dict splined_image_data[direction] = splined_data @@ -816,20 +839,27 @@ def run_splining( # merge grainstats data with other dataframe resultant_grainstats = ( - pd.merge(results_df, splining_stats, on=["image", "threshold", "grain_number"]) - if results_df is not None - else splining_stats + pd.merge(grainstats_df, splining_grainstats, on=["image", 
"threshold", "grain_number"]) + if grainstats_df is not None + else splining_grainstats + ) + # merge molstats data with other dataframe + resultant_molstats = ( + pd.merge(molstats_df, splining_molstats, on=["image", "threshold", "grain_number", "molecule_number"]) + if molstats_df is not None + else splining_molstats ) - image_molstats_df["basename"] = basename.parent # merge all image dictionaries - return splined_image_data, resultant_grainstats, image_molstats_df + return splined_image_data, resultant_grainstats, resultant_molstats except Exception as e: - LOGGER.info(f"Splining failed with {e} - skipping.") - return splined_image_data, splining_stats, image_molstats_df + LOGGER.error( + f"[{filename}] : Splining failed - skipping. Consider raising an issue on GitHub. Error: ", exc_info=e + ) + return splined_image_data, splining_grainstats, splining_molstats - return None, results_df, create_empty_dataframe(columns=["image", "basename", "threshold"]) + return None, grainstats_df, molstats_df def get_out_paths(image_path: Path, base_dir: Path, output_dir: Path, filename: str, plotting_config: dict): @@ -892,7 +922,7 @@ def process_scan( Parameters ---------- topostats_object : dict[str, Union[npt.NDArray, Path, float]] - A dictionary with keys 'image', 'img_path' and 'px_2_nm' containing a file or frames' image, it's path and it's + A dictionary with keys 'image', 'img_path' and 'pixel_to_nm_scaling' containing a file or frames' image, it's path and it's pixel to namometre scaling value. base_dir : str | Path Directory to recursively search for files, if not specified the current directory is scanned. 
@@ -963,7 +993,7 @@ def process_scan( if "above" in topostats_object["grain_masks"].keys() or "below" in topostats_object["grain_masks"].keys(): # Grainstats : - results_df = run_grainstats( + grainstats_df = run_grainstats( image=topostats_object["image_flattened"], pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], grain_masks=topostats_object["grain_masks"], @@ -975,7 +1005,7 @@ def process_scan( ) # Disordered Tracing - disordered_traces_data, results_df, disordered_tracing_stats = run_disordered_trace( + disordered_traces_data, grainstats_df, disordered_tracing_stats = run_disordered_trace( image=topostats_object["image_flattened"], grain_masks=topostats_object["grain_masks"], pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], @@ -984,13 +1014,13 @@ def process_scan( core_out_path=core_out_path, tracing_out_path=tracing_out_path, disordered_tracing_config=disordered_tracing_config, - results_df=results_df, + grainstats_df=grainstats_df, plotting_config=plotting_config, ) topostats_object["disordered_traces"] = disordered_traces_data # Nodestats - nodestats, results_df = run_nodestats( + nodestats, grainstats_df = run_nodestats( image=topostats_object["image_flattened"], disordered_tracing_data=topostats_object["disordered_traces"], pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], @@ -999,42 +1029,43 @@ def process_scan( tracing_out_path=tracing_out_path, plotting_config=plotting_config, nodestats_config=nodestats_config, - results_df=results_df, + grainstats_df=grainstats_df, ) - topostats_object["nodestats"] = nodestats # Ordered Tracing - ordered_tracing, results_df = run_ordered_tracing( + ordered_tracing, grainstats_df, molstats_df = run_ordered_tracing( image=topostats_object["image_flattened"], disordered_tracing_data=topostats_object["disordered_traces"], nodestats_data=nodestats, filename=topostats_object["filename"], + basename=topostats_object["img_path"], core_out_path=core_out_path, tracing_out_path=tracing_out_path, 
ordered_tracing_config=ordered_tracing_config, plotting_config=plotting_config, - results_df=results_df, + grainstats_df=grainstats_df, ) topostats_object["ordered_traces"] = ordered_tracing + topostats_object["nodestats"] = nodestats # looks weird but ordered adds an extra field # splining - splined_data, results_df, molstats_df = run_splining( + splined_data, grainstats_df, molstats_df = run_splining( image=topostats_object["image_flattened"], ordered_tracing_data=topostats_object["ordered_traces"], pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], filename=topostats_object["filename"], - basename=topostats_object["img_path"], core_out_path=core_out_path, plotting_config=plotting_config, splining_config=splining_config, - results_df=results_df, + grainstats_df=grainstats_df, + molstats_df=molstats_df, ) # Add grain trace data to topostats object topostats_object["splining"] = splined_data else: - results_df = create_empty_dataframe() + grainstats_df = create_empty_dataframe() molstats_df = create_empty_dataframe() disordered_tracing_stats = create_empty_dataframe() @@ -1049,7 +1080,7 @@ def process_scan( image_stats = image_statistics( image=image_for_image_stats, filename=topostats_object["filename"], - results_df=results_df, + results_df=grainstats_df, pixel_to_nm_scaling=topostats_object["pixel_to_nm_scaling"], ) @@ -1058,7 +1089,7 @@ def process_scan( output_dir=core_out_path, filename=str(topostats_object["filename"]), topostats_object=topostats_object ) - return topostats_object["img_path"], results_df, image_stats, disordered_tracing_stats, molstats_df + return topostats_object["img_path"], grainstats_df, image_stats, disordered_tracing_stats, molstats_df def check_run_steps( # noqa: C901 @@ -1068,7 +1099,7 @@ def check_run_steps( # noqa: C901 disordered_tracing_run: bool, nodestats_run: bool, splining_run: bool, -) -> None: # noqa: C901 +) -> None: """ Check options for running steps (Filter, Grain, Grainstats and DNA tracing) are logically 
consistent. diff --git a/topostats/tracing/disordered_tracing.py b/topostats/tracing/disordered_tracing.py index 61f792823d..a29f3cce4c 100644 --- a/topostats/tracing/disordered_tracing.py +++ b/topostats/tracing/disordered_tracing.py @@ -120,7 +120,6 @@ def __init__( # pylint: disable=too-many-arguments def trace_dna(self): """Perform the DNA skeletonisation and cleaning pipeline.""" - # LOGGER.info(f"[{self.filename}] : mask_smooth_params : {self.mask_smoothing_params=}") self.smoothed_mask = self.smooth_mask(self.mask, **self.mask_smoothing_params) self.skeleton = getSkeleton( self.image, @@ -128,13 +127,16 @@ def trace_dna(self): method=self.skeletonisation_params["method"], height_bias=self.skeletonisation_params["height_bias"], ).get_skeleton() - self.pruned_skeleton = prune_skeleton(self.image, self.skeleton, **self.pruning_params.copy()) + self.pruned_skeleton = prune_skeleton( + self.image, self.skeleton, self.pixel_to_nm_scaling, **self.pruning_params.copy() + ) self.pruned_skeleton = self.remove_touching_edge(self.pruned_skeleton) self.disordered_trace = np.argwhere(self.pruned_skeleton == 1) if self.disordered_trace is None: - LOGGER.info(f"[{self.filename}] : Grain failed to Skeletonise") + LOGGER.warning(f"[{self.filename}] : Grain {self.n_grain} failed to Skeletonise.") elif len(self.disordered_trace) < self.min_skeleton_size: + LOGGER.warning(f"[{self.filename}] : Grain {self.n_grain} skeleton < {self.min_skeleton_size}, skipping.") self.disordered_trace = None def re_add_holes( @@ -248,9 +250,9 @@ def smooth_mask( gauss = gauss.astype(np.int32) # Add hole to the smooth mask conditional on smallest pixel difference for dilation or the Gaussian smoothing. 
if dilation.sum() > gauss.sum(): - LOGGER.info(f"[{self.filename}] : smoothing done by gaussian {gaussian_sigma}") + LOGGER.debug(f"[{self.filename}] : smoothing done by gaussian {gaussian_sigma}") return self.re_add_holes(grain, gauss, holearea_min_max) - LOGGER.info(f"[{self.filename}] : smoothing done by dilation {dilation_iterations}") + LOGGER.debug(f"[{self.filename}] : smoothing done by dilation {dilation_iterations}") return self.re_add_holes(grain, dilation, holearea_min_max) @@ -313,10 +315,11 @@ def trace_image_disordered( # pylint: disable=too-many-arguments,too-many-local "smoothed_grain": img_base.copy(), "skeleton": img_base.copy(), "pruned_skeleton": img_base.copy(), + "branch_indexes": img_base.copy(), "branch_types": img_base.copy(), } - LOGGER.info(f"[{filename}] : Calculating Disordered Tracing statistics for {n_grains} grains.") + LOGGER.info(f"[{filename}] : Calculating Disordered Tracing statistics for {n_grains} grains...") for cropped_image_index, cropped_image in cropped_images.items(): try: @@ -332,7 +335,16 @@ def trace_image_disordered( # pylint: disable=too-many-arguments,too-many-local min_skeleton_size=min_skeleton_size, n_grain=cropped_image_index, ) - LOGGER.info(f"[{filename}] : Disordered Traced grain {cropped_image_index + 1} of {n_grains}") + LOGGER.debug(f"[{filename}] : Disordered Traced grain {cropped_image_index + 1} of {n_grains}") + + # obtain segment stats + skan_skeleton = skan.Skeleton( + np.where(disordered_trace_images["pruned_skeleton"] == 1, cropped_image, 0), + spacing=pixel_to_nm_scaling, + ) + skan_df = skan.summarize(skan_skeleton) + skan_df = compile_skan_stats(skan_df, skan_skeleton, cropped_image, filename, cropped_image_index) + disordered_tracing_stats = pd.concat((disordered_tracing_stats, skan_df)) # obtain stats conv_pruned_skeleton = convolve_skeleton(disordered_trace_images["pruned_skeleton"]) @@ -341,31 +353,8 @@ def trace_image_disordered( # pylint: disable=too-many-arguments,too-many-local 
"grain_number": cropped_image_index, "grain_endpoints": (conv_pruned_skeleton == 2).sum(), "grain_junctions": (conv_pruned_skeleton == 3).sum(), + "total_branch_lengths": skan_df["branch_distance"].sum(), } - # obtain segment stats - res = skan.summarize( - skan.Skeleton( - np.where(disordered_trace_images["pruned_skeleton"] == 1, cropped_image, 0), - spacing=pixel_to_nm_scaling, - ) - ) - res["image"] = filename - res["grain_number"] = cropped_image_index - disordered_tracing_stats = pd.concat( - ( - disordered_tracing_stats, - res[ - [ - "image", - "grain_number", - "branch-distance", - "branch-type", - "mean-pixel-value", - "stdev-pixel-value", - ] - ], - ) - ) # remap the cropped images back onto the original for image_name, full_image in all_images.items(): @@ -376,7 +365,11 @@ def trace_image_disordered( # pylint: disable=too-many-arguments,too-many-local disordered_trace_crop_data[f"grain_{cropped_image_index}"]["bbox"] = bboxs[cropped_image_index] except Exception as e: # pylint: disable=broad-exception-caught - LOGGER.warning(f"[{filename}] : Disordered tracing of grain {cropped_image_index} failed with {e}.") + LOGGER.error( + f"[{filename}] : Disordered tracing of grain" + + f"{cropped_image_index} failed. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) # convert stats dict to dataframe grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index") @@ -384,6 +377,139 @@ def trace_image_disordered( # pylint: disable=too-many-arguments,too-many-local return disordered_trace_crop_data, grainstats_additions_df, all_images, disordered_tracing_stats +def compile_skan_stats( + skan_df: pd.DataFrame, skan_skeleton: skan.Skeleton, image: npt.NDArray, filename: str, grain_number: int +) -> pd.DataFrame: + """ + Obtain and add more stats to the resultant Skan dataframe. + + Parameters + ---------- + skan_df : pd.DataFrame + The statistics DataFrame produced by Skan's `summarize` function. 
+ skan_skeleton : skan.Skeleton + The graphical representation of the skeleton produced by Skan. + image : npt.NDArray + The image the skeleton was produced from. + filename : str + Name of the file being processed. + grain_number : int + The number of the grain being processed. + + Returns + ------- + pd.DataFrame + A dataframe containing the image, grain_number, branch_distance, branch_type, connected_segments, + mean_pixel_value, stdev_pixel_value, min_value, median_value, and middle_value columns. + """ + skan_df["image"] = filename + skan_df["grain_number"] = grain_number + skan_df["connected_segments"] = skan_df.apply(find_connections, axis=1, skan_df=skan_df) + skan_df["min_value"] = skan_df.apply(lambda x: segment_heights(x, skan_skeleton, image).min(), axis=1) + skan_df["median_value"] = skan_df.apply(lambda x: np.median(segment_heights(x, skan_skeleton, image)), axis=1) + skan_df["middle_value"] = skan_df.apply(segment_middles, skan_skeleton=skan_skeleton, image=image, axis=1) + + skan_df = skan_df.rename( + columns={ # remove with Skan new release + "branch-distance": "branch_distance", + "branch-type": "branch_type", + "mean-pixel-value": "mean_pixel_value", + "stdev-pixel-value": "stdev_pixel_value", + } + ) + + # remove unused skan columns + return skan_df[ + [ + "image", + "grain_number", + "branch_distance", + "branch_type", + "connected_segments", + "mean_pixel_value", + "stdev_pixel_value", + "min_value", + "median_value", + "middle_value", + ] + ] + + +def segment_heights(row: pd.Series, skan_skeleton: skan.Skeleton, image: npt.NDArray) -> npt.NDArray: + """ + Obtain an ordered list of heights from the skan defined skeleton segment. + + Parameters + ---------- + row : pd.Series + A row from the Skan summarize dataframe. + skan_skeleton : skan.Skeleton + The graphical representation of the skeleton produced by Skan. + image : npt.NDArray + The image the skeleton was produced from.
+ + Returns + ------- + npt.NDArray + Heights along the segment, naturally ordered by Skan. + """ + coords = skan_skeleton.path_coordinates(row.name) + return image[coords[:, 0], coords[:, 1]] + + +def segment_middles(row: pd.Series, skan_skeleton: skan.csr.Skeleton, image: npt.NDArray) -> float: + """ + Obtain the pixel value in the middle of the ordered segment. + + Parameters + ---------- + row : pd.Series + A row from the Skan summarize dataframe. + skan_skeleton : skan.csr.Skeleton + The graphical representation of the skeleton produced by Skan. + image : npt.NDArray + The image the skeleton was produced from. + + Returns + ------- + float + The single or mean pixel value corresponding to the middle coordinate(s) of the segment. + """ + heights = segment_heights(row, skan_skeleton, image) + middle_idx, middle_remainder = (len(heights) + 1) // 2 - 1, (len(heights) + 1) % 2 + return heights[[middle_idx, middle_idx + middle_remainder]].mean() + + +def find_connections(row: pd.Series, skan_df: pd.DataFrame) -> str: + """ + Compile the neighbouring branch indexes of the row. + + Parameters + ---------- + row : pd.Series + A row from the Skan summarize dataframe. + skan_df : pd.DataFrame + The statistics DataFrame produced by Skan's `summarize` function. + + Returns + ------- + str + A string representation of a list of matching row indices where the node src and dst + columns match that of the rows. + String is needed for csv compatibility since csvs can't hold lists. 
+ """ + connections = skan_df[ + (skan_df["node-id-src"] == row["node-id-src"]) + | (skan_df["node-id-dst"] == row["node-id-dst"]) + | (skan_df["node-id-src"] == row["node-id-dst"]) + | (skan_df["node-id-dst"] == row["node-id-src"]) + ].index.tolist() + + # Remove the index of the current row itself from the list of connections + connections.remove(row.name) + return str(connections) + + def prep_arrays( image: npt.NDArray, labelled_grains_mask: npt.NDArray, pad_width: int ) -> tuple[dict[int, npt.NDArray], dict[int, npt.NDArray]]: @@ -522,11 +648,16 @@ def disordered_trace_grain( # pylint: disable=too-many-arguments "smoothed_grain": disorderedtrace.smoothed_mask, "skeleton": disorderedtrace.skeleton, "pruned_skeleton": disorderedtrace.pruned_skeleton, - "branch_types": get_branch_type_image(cropped_image, disorderedtrace.pruned_skeleton), + "branch_types": get_skan_image( + cropped_image, disorderedtrace.pruned_skeleton, "branch-type" + ), # change with Skan new release + "branch_indexes": get_skan_image( + cropped_image, disorderedtrace.pruned_skeleton, "node-id-src" + ), # change with Skan new release } -def get_branch_type_image(original_image: npt.NDArray, pruned_skeleton: npt.NDArray) -> npt.NDArray: +def get_skan_image(original_image: npt.NDArray, pruned_skeleton: npt.NDArray, skan_column: str) -> npt.NDArray: """ Label each branch with it's Skan branch type label. @@ -542,22 +673,26 @@ def get_branch_type_image(original_image: npt.NDArray, pruned_skeleton: npt.NDAr Height image from which the pruned skeleton is derived from. pruned_skeleton : npt.NDArray Single pixel thick skeleton mask. + skan_column : str + A column from Skan's summarize function to colour the branch segments with. Returns ------- npt.NDArray 2D array where the background is 0, and skeleton branches label as their Skan branch type. 
""" - branch_type_image = np.zeros_like(original_image) + branch_field_image = np.zeros_like(original_image) skeleton_image = np.where(pruned_skeleton == 1, original_image, 0) skan_skeleton = skan.Skeleton(skeleton_image, spacing=1e-9, value_is_height=True) res = skan.summarize(skan_skeleton) - for i, branch_type in enumerate(res["branch-type"]): + for i, branch_field in enumerate(res[skan_column]): path_coords = skan_skeleton.path_coordinates(i) - branch_type_image[path_coords[:, 0], path_coords[:, 1]] = branch_type + 1 + if skan_column == "node-id-src": + branch_field = i + branch_field_image[path_coords[:, 0], path_coords[:, 1]] = branch_field + 1 - return branch_type_image + return branch_field_image def crop_array(array: npt.NDArray, bounding_box: tuple, pad_width: int = 0) -> npt.NDArray: diff --git a/topostats/tracing/dnatracing.py b/topostats/tracing/dnatracing.py index b606dbc9f1..b81e8217a0 100644 --- a/topostats/tracing/dnatracing.py +++ b/topostats/tracing/dnatracing.py @@ -213,7 +213,7 @@ def trace_dna(self): mask=self.mask, smoothed_mask=self.smoothed_mask, skeleton=self.pruned_skeleton, - px_2_nm=self.pixel_to_nm_scaling, + pixel_to_nm_scaling=self.pixel_to_nm_scaling, n_grain=self.n_grain, node_joining_length=self.joining_node_length, ) diff --git a/topostats/tracing/nodestats.py b/topostats/tracing/nodestats.py index 6098e2400f..c5b69f88cf 100644 --- a/topostats/tracing/nodestats.py +++ b/topostats/tracing/nodestats.py @@ -3,6 +3,7 @@ from __future__ import annotations import logging +from itertools import combinations from typing import TypedDict import networkx as nx @@ -40,7 +41,7 @@ class NodeDict(TypedDict): """Dictionary containing the node information.""" error: bool - px_2_nm: np.float64 + pixel_to_nm_scaling: np.float64 branch_stats: dict[int, MatchedBranch] | None node_coords: npt.NDArray[np.int32] | None confidence: np.float64 | None @@ -90,7 +91,7 @@ class nodeStats: A smoothed version of the bianary segmentation mask. 
skeleton : npt.NDArray A binary single-pixel wide mask of objects in the 'image'. - px_2_nm : np.float32 + pixel_to_nm_scaling : np.float32 The pixel to nm scaling factor. n_grain : int The grain number. @@ -111,7 +112,7 @@ def __init__( mask: npt.NDArray, smoothed_mask: npt.NDArray, skeleton: npt.NDArray, - px_2_nm: np.float64, + pixel_to_nm_scaling: np.float64, n_grain: int, node_joining_length: float, node_extend_dist: float, @@ -133,7 +134,7 @@ def __init__( A smoothed version of the bianary segmentation mask. skeleton : npt.NDArray A binary single-pixel wide mask of objects in the 'image'. - px_2_nm : float + pixel_to_nm_scaling : float The pixel to nm scaling factor. n_grain : int The grain number. @@ -151,10 +152,10 @@ def __init__( self.mask = mask self.smoothed_mask = smoothed_mask # only used to average traces self.skeleton = skeleton - self.px_2_nm = px_2_nm + self.pixel_to_nm_scaling = pixel_to_nm_scaling self.n_grain = n_grain self.node_joining_length = node_joining_length - self.node_extend_dist = node_extend_dist / self.px_2_nm + self.node_extend_dist = node_extend_dist / self.pixel_to_nm_scaling self.branch_pairing_length = branch_pairing_length self.pair_odd_branches = pair_odd_branches @@ -214,19 +215,19 @@ def get_node_stats(self) -> tuple: |-> 'grain_mask' └-> 'grain_skeleton' """ - LOGGER.info(f"Node Stats - Processing Grain: {self.n_grain}") + LOGGER.debug(f"Node Stats - Processing Grain: {self.n_grain}") self.conv_skelly = convolve_skeleton(self.skeleton) if len(self.conv_skelly[self.conv_skelly == 3]) != 0: # check if any nodes - LOGGER.info(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.") + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} contains crossings.") # convolve to see crossing and end points - self.conv_skelly = self.tidy_branches(self.conv_skelly, self.image) + # self.conv_skelly = self.tidy_branches(self.conv_skelly, self.image) # reset skeleton var as tidy branches may have modified it
self.skeleton = np.where(self.conv_skelly != 0, 1, 0) self.image_dict["grain"]["grain_skeleton"] = self.skeleton # get graph of skeleton self.whole_skel_graph = self.skeleton_image_to_graph(self.skeleton) # connect the close nodes - LOGGER.info(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.") + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} connecting close nodes.") self.connected_nodes = self.connect_close_nodes(self.conv_skelly, node_width=self.node_joining_length) # connect the odd-branch nodes self.connected_nodes = self.connect_extended_nodes_nearest( @@ -235,11 +236,11 @@ def get_node_stats(self) -> tuple: # obtain a mask of node centers and their count self.node_centre_mask = self.highlight_node_centres(self.connected_nodes) # Begin the hefty crossing analysis - LOGGER.info(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.") + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} analysing found crossings.") self.analyse_nodes(max_branch_length=self.branch_pairing_length) self.compile_metrics() else: - LOGGER.info(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.") + LOGGER.debug(f"[{self.filename}] : Nodestats - {self.n_grain} has no crossings.") return self.node_dicts, self.image_dict # self.all_visuals_img = dnaTrace.concat_images_in_dict(self.image.shape, self.visuals) @@ -330,7 +331,7 @@ def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> n node_centre = coords.mean(axis=0).astype(np.int32) node_wid = coords[:, 0].max() - coords[:, 0].min() + 2 # +2 so always 2 by default node_len = coords[:, 1].max() - coords[:, 1].min() + 2 # +2 so always 2 by default - overflow = int(10 / self.px_2_nm) if int(10 / self.px_2_nm) != 0 else 1 + overflow = int(10 / self.pixel_to_nm_scaling) if int(10 / self.pixel_to_nm_scaling) != 0 else 1 # grain mask fill new_skeleton[ node_centre[0] - node_wid // 2 - overflow : node_centre[0] + node_wid // 2 + 
overflow, @@ -346,7 +347,9 @@ def tidy_branches(self, connect_node_mask: npt.NDArray, image: npt.NDArray) -> n # new_skeleton = pruneSkeleton(image, new_skeleton).prune_skeleton( # {"method": "topostats", "max_length": -1} # ) - new_skeleton = prune_skeleton(image, new_skeleton, **{"method": "topostats", "max_length": -1}) + new_skeleton = prune_skeleton( + image, new_skeleton, self.pixel_to_nm_scaling, **{"method": "topostats", "max_length": -1} + ) # cleanup around nibs new_skeleton = getSkeleton(image, new_skeleton, method="zhang").get_skeleton() # might also need to remove segments that have squares connected @@ -374,7 +377,7 @@ def keep_biggest_object(mask: npt.NDArray) -> npt.NDArray: max_idx = idxs[np.argmax(counts[1:]) + 1] return np.where(labelled_mask == max_idx, 1, 0) except ValueError as e: - LOGGER.info(f"{e}: mask is empty.") + LOGGER.debug(f"{e}: mask is empty.") return mask def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85) -> npt.NDArray: @@ -400,7 +403,7 @@ def connect_close_nodes(self, conv_skelly: npt.NDArray, node_width: float = 2.85 nodeless[(nodeless == 3) | (nodeless == 2)] = 0 # remove node & termini points nodeless_labels = label(nodeless) for i in range(1, nodeless_labels.max() + 1): - if nodeless[nodeless_labels == i].size < (node_width / self.px_2_nm): + if nodeless[nodeless_labels == i].size < (node_width / self.pixel_to_nm_scaling): # maybe also need to select based on height? 
and also ensure small branches classified self.connected_nodes[nodeless_labels == i] = 3 @@ -551,7 +554,7 @@ def analyse_nodes(self, max_branch_length: float = 20) -> None: error = False # Get branches relevant to the node - max_length_px = max_branch_length / (self.px_2_nm * 1) + max_length_px = max_branch_length / (self.pixel_to_nm_scaling * 1) reduced_node_area: npt.NDArray[np.int32] = nodeStats.only_centre_branches( self.connected_nodes, np.array([node_x, node_y]) ) @@ -569,14 +572,14 @@ def analyse_nodes(self, max_branch_length: float = 20) -> None: # Stop processing if nib (node has 2 branches) if branch_start_coords.shape[0] <= 2: - LOGGER.info( + LOGGER.debug( f"node {node_no} has only two branches - skipped & nodes removed.{len(node_coords)}" "pixels in nib node." ) else: try: real_node_count += 1 - LOGGER.info(f"Node: {real_node_count}") + LOGGER.debug(f"Node: {real_node_count}") # Analyse the node branches ( @@ -585,10 +588,10 @@ def analyse_nodes(self, max_branch_length: float = 20) -> None: ordered_branches, masked_image, branch_under_over_order, - conf, + confidence, singlet_branch_vectors, ) = nodeStats.analyse_node_branches( - p_to_nm=self.px_2_nm, + p_to_nm=self.pixel_to_nm_scaling, reduced_node_area=reduced_node_area, branch_start_coords=branch_start_coords, max_length_px=max_length_px, @@ -641,16 +644,16 @@ def analyse_nodes(self, max_branch_length: float = 20) -> None: # angles_between_vectors_along_branch except ResolutionError: - LOGGER.info(f"Node stats skipped as resolution too low: {self.px_2_nm}nm per pixel") + LOGGER.debug(f"Node stats skipped as resolution too low: {self.pixel_to_nm_scaling}nm per pixel") error = True self.node_dicts[f"node_{real_node_count}"] = { "error": error, - "px_2_nm": self.px_2_nm, + "pixel_to_nm_scaling": self.pixel_to_nm_scaling, "branch_stats": matched_branches, "unmatched_branch_stats": unmatched_branches, "node_coords": node_coords, - "confidence": conf, + "confidence": confidence, } assert reduced_node_area 
is not None, "Reduced node area is not defined." @@ -798,7 +801,7 @@ def analyse_node_branches( keys: - "ordered_coords" : npt.NDArray[np.int32]. - "heights" : npt.NDArray[np.number]. Heights of the branches. - - "distances" : + - "distances" : npt.NDArray[np.number]. The accumulating distance along the branch. - "fwhm" : npt.NDArray[np.number]. Full width half maximum of the branches. - "angles" : np.float64. The angle of the branch, added in later steps. ordered_branches: list[npt.NDArray[np.int32]] @@ -809,13 +812,11 @@ def analyse_node_branches( - "avg_mask" : npt.NDArray[np.bool_]. Average mask of the branches. branch_under_over_order: npt.NDArray[np.int32] The order of the branches based on the FWHM. - conf: np.float64 | None + confidence: np.float64 | None The confidence of the crossing. Optional. """ if not p_to_nm <= resolution_threshold: - LOGGER.warning( - f"Resolution {p_to_nm} is below suggested {resolution_threshold}, node difficult to analyse." - ) + LOGGER.debug(f"Resolution {p_to_nm} is below suggested {resolution_threshold}, node difficult to analyse.") # Pixel-wise order the branches coming from the node and calculate the starting vector for each branch ordered_branches, singlet_branch_vectors = nodeStats.get_ordered_branches_and_vectors( @@ -847,20 +848,17 @@ def analyse_node_branches( values["fwhm"] = nodeStats.calculate_fwhm(values["heights"], values["distances"], hm=max(hms)) # Get the confidence of the crossing - crossing_quants = [] + crossing_fwhms = [] for _, values in matched_branches.items(): - crossing_quants.append(values["fwhm"]["fwhm"]) - if len(crossing_quants) <= 1: - conf = None + crossing_fwhms.append(values["fwhm"]["fwhm"]) + if len(crossing_fwhms) <= 1: + confidence = None else: - combs = nodeStats.get_two_combinations(crossing_quants) - conf = np.float64(nodeStats.cross_confidence(combs)) + crossing_fwhm_combinations = list(combinations(crossing_fwhms, 2)) + confidence = 
np.float64(nodeStats.cross_confidence(crossing_fwhm_combinations)) - fwhms = [] - for _, values in matched_branches.items(): - fwhms.append(values["fwhm"]["fwhm"]) # Order the branch indexes based on the FWHM of the branches. - branch_under_over_order = np.array(list(matched_branches.keys()))[np.argsort(np.array(fwhms))] + branch_under_over_order = np.array(list(matched_branches.keys()))[np.argsort(np.array(crossing_fwhms))] return ( pairs, @@ -868,7 +866,7 @@ def analyse_node_branches( ordered_branches, masked_image, branch_under_over_order, - conf, + confidence, singlet_branch_vectors, ) @@ -963,7 +961,7 @@ def join_matching_branches_through_node( AssertionError, IndexError, ) as e: # Assertion - avg trace not advised, Index - wiggy branches - LOGGER.info(f"[{filename}] : avg trace failed with {e}, single trace only.") + LOGGER.debug(f"[{filename}] : avg trace failed with {e}, single trace only.") average_trace_advised = False distances = nodeStats.coord_dist_rad(single_branch_coords, np.array([node_coords[0], node_coords[1]])) # distances = self.coord_dist(single_branch_coords) @@ -1030,37 +1028,14 @@ def get_ordered_branches_and_vectors( return ordered_branches, vectors @staticmethod - def get_two_combinations(fwhm_list) -> list: - """ - Obtain all paired combinations of values in the list. - - Example: [1,2] -> [[1,2]], [1,2,3] -> [[1,2],[1,3],[2,3]] - - Parameters - ---------- - fwhm_list : list - List of FWHMs from crossing analysis. - - Returns - ------- - list - A list of pairs of 'fwhm_list' values. - """ - combs = [] - for i in range(len(fwhm_list) - 1): - for j in fwhm_list[i + 1 :]: - combs.append([fwhm_list[i], j]) - return combs - - @staticmethod - def cross_confidence(combs: list) -> float: + def cross_confidence(pair_combinations: list) -> float: """ Obtain the average confidence of the combinations using a reciprical function. Parameters ---------- - combs : list - List of combinations of FWHM values. 
+ pair_combinations : list + List of length 2 combinations of FWHM values. Returns ------- @@ -1068,9 +1043,9 @@ def cross_confidence(combs: list) -> float: The average crossing confidence. """ c = 0 - for comb in combs: - c += nodeStats.recip(comb) - return c / len(combs) + for pair in pair_combinations: + c += nodeStats.recip(pair) + return c / len(pair_combinations) @staticmethod def recip(vals: list) -> float: @@ -1447,7 +1422,7 @@ def binary_line(start: npt.NDArray, end: npt.NDArray) -> npt.NDArray: return arr @staticmethod - def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, px_2_nm: float = 1) -> npt.NDArray: + def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray: """ Calculate the distance from the centre coordinate to a point along the ordered coordinates. @@ -1460,7 +1435,7 @@ def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, px_2_nm: float = 1) Nx2 array of branch coordinates. centre : npt.NDArray A 1x2 array of the centre coordinates to identify a 0 point for the node. - px_2_nm : float, optional + pixel_to_nm_scaling : float, optional The pixel to nanometer scaling factor to provide real units, by default 1. 
Returns @@ -1475,7 +1450,7 @@ def coord_dist_rad(coords: npt.NDArray, centre: npt.NDArray, px_2_nm: float = 1) cross_idx = np.argwhere(np.all(coords == centre, axis=1)) rad_dist = np.sqrt(diff_coords[:, 0] ** 2 + diff_coords[:, 1] ** 2) rad_dist[0 : cross_idx[0][0]] *= -1 - return rad_dist * px_2_nm + return rad_dist * pixel_to_nm_scaling @staticmethod def above_below_value_idx(array: npt.NDArray, value: float) -> list: @@ -1802,9 +1777,9 @@ def average_crossing_confs(node_dict) -> None | float: sum_conf = 0 valid_confs = 0 for _, (_, values) in enumerate(node_dict.items()): - conf = values["confidence"] - if conf is not None: - sum_conf += conf + confidence = values["confidence"] + if confidence is not None: + sum_conf += confidence valid_confs += 1 try: return sum_conf / valid_confs @@ -1826,15 +1801,15 @@ def minimum_crossing_confs(node_dict: dict) -> None | float: Union[None, float] The value of minimum confidence or none if not possible. """ - confs = [] + confidences = [] valid_confs = 0 for _, (_, values) in enumerate(node_dict.items()): - conf = values["confidence"] - if conf is not None: - confs.append(conf) + confidence = values["confidence"] + if confidence is not None: + confidences.append(confidence) valid_confs += 1 try: - return min(confs) + return min(confidences) except ValueError: return None @@ -1906,7 +1881,7 @@ def nodestats_image( nodestats_branch_images = {} grainstats_additions = {} - LOGGER.info(f"[{filename}] : Calculating NodeStats statistics for {n_grains} grains.") + LOGGER.info(f"[{filename}] : Calculating NodeStats statistics for {n_grains} grains...") for n_grain, disordered_tracing_grain_data in disordered_tracing_direction_data.items(): nodestats = None # reset the nodestats variable @@ -1916,7 +1891,7 @@ def nodestats_image( mask=disordered_tracing_grain_data["original_grain"], smoothed_mask=disordered_tracing_grain_data["smoothed_grain"], skeleton=disordered_tracing_grain_data["pruned_skeleton"], - px_2_nm=pixel_to_nm_scaling, + 
pixel_to_nm_scaling=pixel_to_nm_scaling, filename=filename, n_grain=n_grain, node_joining_length=node_joining_length, @@ -1925,7 +1900,7 @@ def nodestats_image( pair_odd_branches=pair_odd_branches, ) nodestats_dict, node_image_dict = nodestats.get_node_stats() - LOGGER.info(f"[{filename}] : Nodestats processed {n_grain} of {n_grains}") + LOGGER.debug(f"[{filename}] : Nodestats processed {n_grain} of {n_grains}") # compile images nodestats_images = { @@ -1951,7 +1926,10 @@ def nodestats_image( full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width] except Exception as e: # pylint: disable=broad-exception-caught - LOGGER.error(f"[{filename}] : Nodestats for {n_grain} failed with - {e}") + LOGGER.error( + f"[{filename}] : Nodestats for {n_grain} failed. Consider raising an issue on GitHub. Error: ", + exc_info=e, + ) nodestats_data[n_grain] = {} # turn the grainstats additions into a dataframe, # might need to do something for when everything is empty diff --git a/topostats/tracing/ordered_tracing.py b/topostats/tracing/ordered_tracing.py index efb28341f1..514f94517a 100644 --- a/topostats/tracing/ordered_tracing.py +++ b/topostats/tracing/ordered_tracing.py @@ -3,11 +3,13 @@ from __future__ import annotations import logging +from itertools import combinations import numpy as np import numpy.typing as npt import pandas as pd -from skimage.morphology import label +from skimage.morphology import binary_dilation, label +from topoly import jones, translate_code from topostats.logs.logs import LOGGER_NAME from topostats.tracing.tracingfuncs import coord_dist, genTracingFuncs, order_branch, reorderTrace @@ -62,9 +64,7 @@ def __init__( "num_mols": 0, "circular": None, } - self.mol_tracing_stats = {"circular": None} - - self.ordered_traces = None + self.mol_tracing_stats = {"circular": None, "topology": None, "topology_flip": None, "processing": "nodestats"} self.images = { "over_under": np.zeros_like(image), @@ -75,23 +75,29 @@ 
def __init__( self.profiles = {} + self.img_idx_to_node = {} + self.ordered_coordinates = [] # pylint: disable=too-many-locals - def compile_trace(self) -> tuple[list, npt.NDArray]: # noqa: C901 + # pylint: disable=too-many-branches + def compile_trace(self, reverse_min_conf_crossing: bool = False) -> tuple[list, npt.NDArray]: # noqa: C901 """ Obtain the trace and diagnostic crossing trace and molecule trace images. This function uses the branches and full-width half-maximums (FWHMs) identified in the node_stats dictionary to create a continuous trace of the molecule. + Parameters + ---------- + reverse_min_conf_crossing : bool + Whether to reverse the stacking order of the lowest confidence crossing in the trace. + Returns ------- tuple[list, npt.NDArray] - A list of each complete path's ordered coordinates, and labeled crosing image array. + A list of each complete path's ordered coordinates, and labeled crossing image array. """ - LOGGER.info(f"[{self.filename}] : Compiling the trace.") - # iterate through the dict to get branch coords, heights and fwhms node_coords = [ [stats["node_coords"] for branch_stats in stats["branch_stats"].values() if branch_stats["fwhm"]["fwhm"]] @@ -119,35 +125,59 @@ def compile_trace(self) -> tuple[list, npt.NDArray]: # noqa: C901 ] fwhms = [lst for lst in fwhms if lst] + confidences = [stats["confidence"] for stats in self.nodestats_dict.values()] + + # obtain the index of the underlying branch + try: + low_conf_idx = np.nanargmin(np.array(confidences, dtype=float)) + except ValueError: # when no crossings or only 3-branch crossings + low_conf_idx = None + # Get the image minus the crossing regions - minus = self.skeleton.copy() + nodes = np.zeros_like(self.skeleton) + for node_no in node_coords: # this stops unpaired branches from interacting with the pairs + nodes[node_no[0][:, 0], node_no[0][:, 1]] = 1 + minus = np.where(binary_dilation(binary_dilation(nodes)) == self.skeleton, 0, self.skeleton) + # remove crossings from 
skeleton for crossings in crossing_coords: for crossing in crossings: minus[crossing[:, 0], crossing[:, 1]] = 0 minus = label(minus) - # Get both image - both = minus.copy() - for node_num, crossings in enumerate(crossing_coords): - for crossing_num, crossing in enumerate(crossings): - both[crossing[:, 0], crossing[:, 1]] = node_num + crossing_num + minus.max() + # setup z array + z = [] # order minus segments ordered = [] - for i in range(1, minus.max() + 1): - arr = np.where(minus, minus == i, 0) + for non_cross_segment_idx in range(1, minus.max() + 1): + arr = np.where(minus, minus == non_cross_segment_idx, 0) ordered.append(order_branch(arr, [0, 0])) # orientated later + z.append(0) + self.img_idx_to_node[non_cross_segment_idx] = {} # add crossing coords to ordered segment list - for i, node_crossing_coords in enumerate(crossing_coords): - for j, single_cross in enumerate(node_crossing_coords): + uneven_count = non_cross_segment_idx + 1 + for node_num, node_crossing_coords in enumerate(crossing_coords): + z_idx = np.argsort(fwhms[node_num]) + z_idx[z_idx == 0] = -1 + if reverse_min_conf_crossing and low_conf_idx == node_num: + z_idx = z_idx[::-1] + fwhms[node_num] = fwhms[node_num][::-1] + for node_cross_idx, single_cross in enumerate(node_crossing_coords): # check current single cross has no duplicate coords with ordered, except crossing points uncommon_single_cross = np.array(single_cross).copy() for coords in ordered: uncommon_single_cross = self.remove_common_values( - uncommon_single_cross, np.array(coords), retain=node_coords[i][j] + uncommon_single_cross, np.array(coords), retain=node_coords[node_num][node_cross_idx] ) if len(uncommon_single_cross) > 0: ordered.append(uncommon_single_cross) + z.append(z_idx[node_cross_idx]) + self.img_idx_to_node[uneven_count + node_cross_idx] = { + "node_idx": node_num, + "coords": single_cross, + "z_idx": z_idx[node_cross_idx], + } + uneven_count += len(node_crossing_coords) # get an image of each ordered segment 
cross_add = np.zeros_like(self.image) @@ -155,8 +185,32 @@ def compile_trace(self) -> tuple[list, npt.NDArray]: # noqa: C901 single_cross_img = coords_2_img(np.array(coords), cross_add) cross_add[single_cross_img != 0] = i + 1 - coord_trace = self.trace(ordered, cross_add) + coord_trace, simple_trace = self.trace(ordered, cross_add, z, n=100) + # obtain topology from the simple trace + topology = self.get_topology(simple_trace) + if reverse_min_conf_crossing and low_conf_idx is None: # when there's nothing to reverse + topology = [None for _ in enumerate(topology)] + + return coord_trace, topology, cross_add, crossing_coords, fwhms + + def compile_images(self, coord_trace: list, cross_add: npt.NDArray, crossing_coords: list, fwhms: list) -> None: + """ + Obtain all the diagnostic images based on the produced traces, and values. + + Crossing coords and fwhms are used as arguments as reversing the minimum confidence can modify these. + + Parameters + ---------- + coord_trace : list + List of N molecule objects containing 2xM arrays of X, Y coordinates. + cross_add : npt.NDArray + A labelled array with segments of the ordered trace. + crossing_coords : list + A list of I nodes objects containing 2xJ arrays of X, Y coordinates for each crossing branch. + fwhms : list + A list of I nodes objects containing FWHM values for each crossing branch.
+ """ # visual over under img self.images["trace_segments"] = cross_add try: @@ -166,8 +220,6 @@ def compile_trace(self) -> tuple[list, npt.NDArray]: # noqa: C901 pass self.images["ordered_traces"] = ordered_trace_mask(coord_trace, self.image.shape) - return coord_trace, self.images - @staticmethod def remove_common_values( ordered_array: npt.NDArray, common_value_check_array: npt.NDArray, retain: list = () @@ -201,7 +253,53 @@ def remove_common_values( return np.asarray(filtered_arr1) - def trace(self, ordered_segment_coords: list, both_img: npt.NDArray) -> list: + def get_topology(self, nxyz: npt.NDArray) -> list: + """ + Obtain a topological classification from ordered XYZ coordinates. + + Parameters + ---------- + nxyz : npt.NDArray + A 4xN array of the order index (n), x, y and pseudo z coordinates. + + Returns + ------- + list + Topology(s) of the provided traced coordinates. + """ + # Topoly doesn't work when 2 mols don't actually cross + topology = [] + lin_idxs = [] + nxyz_cp = nxyz.copy() + # remove linear mols as they are just Reidemeister moves + for i, mol_trace in enumerate(nxyz): + if mol_trace[-1][0] != 0: # mol is not circular + topology.append("linear") + lin_idxs.append(i) + # remove from list in reverse order so no conflicts + lin_idxs.sort(reverse=True) + for i in lin_idxs: + del nxyz_cp[i] + # classify topology for non-Reidemeister moves + if len(nxyz_cp) != 0: + try: + pd_code = translate_code( + nxyz_cp, output_type="pdcode" + ) # pd code helps prevent freezing and spawning multiple processes + LOGGER.debug(f"{self.filename} : PD Code is: {pd_code}") + top_class = jones(pd_code) + except (IndexError, KeyError): + LOGGER.debug(f"{self.filename} : PD Code could not be obtained from trace coordinates.") + top_class = "N/A" + + # don't separate catenanes / overlaps - used for distribution comparison + for _ in range(len(nxyz_cp)): + topology.append(top_class) + + return topology + + def trace(self, ordered_segment_coords: list, both_img: npt.NDArray,
zs: npt.NDArray, n: int = 100) -> list: + # pylint: disable=too-many-locals """ Obtain an ordered trace of each complete path. @@ -214,6 +312,11 @@ def trace(self, ordered_segment_coords: list, both_img: npt.NDArray) -> list: Ordered coordinates of each labeled segment in 'both_img'. both_img : npt.NDArray A skeletonised labeled image of each path segment. + zs : npt.NDArray + Array of pseudo heights of the traces. -1 is lowest, 0 is skeleton, then ascending integers for + levels of overs. + n : int + The number of points to use for the simplified traces. Returns ------- @@ -221,9 +324,11 @@ def trace(self, ordered_segment_coords: list, both_img: npt.NDArray) -> list: Ordered trace coordinates of each complete path. """ mol_coords = [] + simple_coords = [] remaining = both_img.copy().astype(np.int32) endpoints = np.unique(remaining[convolve_skeleton(remaining.astype(bool)) == 2]) # unique in case of whole mol prev_segment = None + n_points_p_seg = (n - 2 * remaining.max()) // remaining.max() while remaining.max() != 0: # select endpoint to start if there is one @@ -232,22 +337,87 @@ def trace(self, ordered_segment_coords: list, both_img: npt.NDArray) -> list: coord_idx = endpoints.pop(0) - 1 else: # if no endpoints, just a loop coord_idx = np.unique(remaining)[1] - 1 # avoid choosing 0 - coord_trace = np.empty((0, 2)).astype(np.int32) + coord_trace = np.empty((0, 3)).astype(np.int32) + simple_trace = np.empty((0, 3)).astype(np.int32) + while coord_idx > -1: # either cycled through all or hits terminus -> all will be just background remaining[remaining == coord_idx + 1] = 0 trace_segment = self.get_trace_segment(remaining, ordered_segment_coords, coord_idx) + full_trace_segment = trace_segment.copy() if len(coord_trace) > 0: # can only order when there's a reference point / segment trace_segment = self.remove_common_values( trace_segment, prev_segment ) # remove overlaps in trace (may be more efficient to do it on the previous segment) - trace_segment = 
self.order_from_end(coord_trace[-1], trace_segment) + trace_segment, flipped = self.order_from_end(coord_trace[-1, :2], trace_segment) + full_trace_segment = full_trace_segment[::-1] if flipped else full_trace_segment + # get vector if crossing + if self.img_idx_to_node[coord_idx + 1]: + segment_vector = full_trace_segment[-1] - full_trace_segment.mean( + axis=0 + ) # from start to mean coord + segment_vector /= np.sqrt(segment_vector @ segment_vector) # normalise + self.img_idx_to_node[coord_idx + 1]["vector"] = segment_vector prev_segment = trace_segment.copy() # update previous segment - coord_trace = np.append(coord_trace, trace_segment.astype(np.int32), axis=0) - x, y = coord_trace[-1] + trace_segment_z = np.column_stack( + (trace_segment, np.ones((len(trace_segment), 1)) * zs[coord_idx]) + ).astype( + np.int32 + ) # add z's + coord_trace = np.append(coord_trace, trace_segment_z.astype(np.int32), axis=0) + + # obtain a reduced coord version of the traces for Topoly + simple_trace_temp = self.reduce_rows( + trace_segment.astype(np.int32), n=n_points_p_seg + ) # reducing rows here ensures no segments are skipped + simple_trace_temp_z = np.column_stack( + (simple_trace_temp, np.ones((len(simple_trace_temp), 1)) * zs[coord_idx]) + ).astype( + np.int32 + ) # add z's + simple_trace = np.append(simple_trace, simple_trace_temp_z, axis=0) + + x, y = coord_trace[-1, :2] coord_idx = remaining[x - 1 : x + 2, y - 1 : y + 2].max() - 1 # should only be one value mol_coords.append(coord_trace) - return mol_coords + # Issue in 0_5 where wrong nxyz[0] selected, and == nxyz[-1] so always duplicated + nxyz = np.column_stack((np.arange(0, len(simple_trace)), simple_trace)) + end_to_end_dist_squared = (nxyz[0][1] - nxyz[-1][1]) ** 2 + (nxyz[0][2] - nxyz[-1][2]) ** 2 + if len(nxyz) > 2 and end_to_end_dist_squared <= 2: # pylint: disable=chained-comparison + # single coord traces mean nxyz[0]==[1] so cause issues when duplicating for topoly + nxyz = np.append(nxyz, 
nxyz[0][np.newaxis, :], axis=0) + simple_coords.append(nxyz) + + # convert into lists for Topoly + simple_coords = [[list(row) for row in mol] for mol in simple_coords] + + return mol_coords, simple_coords + + @staticmethod + def reduce_rows(array: npt.NDArray, n: int = 300) -> npt.NDArray: + """ + Reduce the number of rows in the array to `n`, keeping the first and last indexes. + + Parameters + ---------- + array : npt.NDArray + An array to reduce the number of rows in. + n : int, optional + The number of indexes in the array to keep, by default 300. + + Returns + ------- + npt.NDArray + The `array` reduced to only `n` + 2 elements, or if shorter, the same array. + """ + # reduce the number of rows (while keeping the first and last ones) + if array.shape[0] < n or array.shape[0] < 4: + return array + + idxs_to_keep = np.unique(np.linspace(0, array[1:-1].shape[0] - 1, n).astype(np.int32)) + new_array = array[1:-1][idxs_to_keep] + new_array = np.append(array[0][np.newaxis, :], new_array, axis=0) + return np.append(new_array, array[-1][np.newaxis, :], axis=0) @staticmethod def get_trace_segment(remaining_img: npt.NDArray, ordered_segment_coords: list, coord_idx: int) -> npt.NDArray: @@ -294,16 +464,18 @@ def order_from_end(last_segment_coord: npt.NDArray, current_segment: npt.NDArray ------- npt.NDArray The current segment orientated to follow on from the last. + bool + Whether the order has been flipped. """ start_xy = current_segment[0] dist = np.sum((start_xy - last_segment_coord) ** 2) ** 0.5 if dist <= np.sqrt(2): - return current_segment - return current_segment[::-1] + return current_segment, False + return current_segment[::-1], True def get_over_under_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> npt.NDArray: """ - Obtain a labeled image according to the main trace (=1), under (=2), over (=3). + Obtain a labelled image according to the main trace (=1), under (=2), over (=3).
Parameters ---------- @@ -317,7 +489,7 @@ def get_over_under_img(self, coord_trace: list, fwhms: list, crossing_coords: li Returns ------- npt.NDArray - 2D crossing order labeled image. + 2D crossing order labelled image. """ # put down traces img = np.zeros_like(self.skeleton) @@ -325,10 +497,10 @@ def get_over_under_img(self, coord_trace: list, fwhms: list, crossing_coords: li temp_img = np.zeros_like(img) temp_img[coords[:, 0], coords[:, 1]] = 1 # temp_img = binary_dilation(temp_img) - img[temp_img != 0] = 1 # mol_no + 1 - lower_idxs, upper_idxs = self.get_trace_idxs(fwhms) + img[temp_img != 0] = 1 - # place over/unders onto image array + # place over/under strands onto image array + lower_idxs, upper_idxs = self.get_trace_idxs(fwhms) for i, type_idxs in enumerate([lower_idxs, upper_idxs]): for crossing, type_idx in zip(crossing_coords, type_idxs): temp_img = np.zeros_like(img) @@ -341,8 +513,9 @@ def get_over_under_img(self, coord_trace: list, fwhms: list, crossing_coords: li # pylint: disable=too-many-locals def get_mols_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> npt.NDArray: + # pylint: disable=too-many-locals """ - Obtain a labeled image according to each molecule traced N=3 -> n=1,2,3. + Obtain a labelled image according to each molecule traced N=3 -> n=1,2,3. Parameters ---------- @@ -356,7 +529,7 @@ def get_mols_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> Returns ------- npt.NDArray - 2D individual 'molecule' labeled image. + 2D individual 'molecule' labelled image. 
""" img = np.zeros_like(self.skeleton) for mol_no, coords in enumerate(coord_trace): @@ -376,7 +549,7 @@ def get_mols_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> c = 0 # get overlaps between segment coords and crossing under coords for cross_coord in cross_coords: - c += ((trace == cross_coord).sum(axis=1) == 2).sum() + c += ((trace[:, :2] == cross_coord).sum(axis=1) == 2).sum() matching_coords = np.append(matching_coords, c) val = matching_coords.argmax() + 1 temp_img[cross_coords[:, 0], cross_coords[:, 1]] = 1 @@ -387,7 +560,7 @@ def get_mols_img(self, coord_trace: list, fwhms: list, crossing_coords: list) -> @staticmethod def get_trace_idxs(fwhms: list) -> tuple[list, list]: """ - Split underpassing and overpassing indices. + Split under-passing and over-passing indices. Parameters ---------- @@ -422,6 +595,63 @@ def check_node_errorless(self) -> bool: return False return True + def identify_writhes(self) -> str | dict: + """ + Identify the writhe topology at each crossing in the image. + + Returns + ------- + str | dict + A string of the whole grain writhe sign, and a dictionary linking each node to it's sign. + """ + # compile all vectors for each node and their z_idx + # - want for each node, ordered vectors according to z_idx + writhe_string = "" + node_to_writhe = {} + idx2node_df = pd.DataFrame.from_dict(self.img_idx_to_node, orient="index") + if idx2node_df.empty: # for when no crossovers but still crossings (i.e. 
unpaired 3-way) + return "", {} + + for node_num, node_df in idx2node_df.groupby("node_idx"): + vector_series = node_df.sort_values(by=["z_idx"], ascending=False)["vector"] + vectors = list(vector_series) + # get pairs + vector_combinations = list(combinations(vectors, 2)) + # calculate the writhe + temp_writhes = "" + for vector_pair in vector_combinations: # if > 2 crossing branches + temp_writhes += self.writhe_direction(vector_pair[0], vector_pair[1]) + if len(temp_writhes) > 1: + temp_writhes = f"({temp_writhes})" + node_to_writhe[node_num] = temp_writhes + writhe_string += temp_writhes + + return writhe_string, node_to_writhe + + @staticmethod + def writhe_direction(first_vector: npt.NDArray, second_vector: npt.NDArray) -> str: + """ + Use the cross product of crossing vectors to determine the writhe sign. + + Parameters + ---------- + first_vector : npt.NDArray + An x,y component vector of the overlying strand. + second_vector : npt.NDArray + An x,y component vector of the underlying strand. + + Returns + ------- + str + '+', '-' or '0' for positive, negative, or no writhe. + """ + cross = np.cross(first_vector, second_vector) + if cross < 0: + return "-" + if cross > 0: + return "+" + return "0" + def run_nodestats_tracing(self) -> tuple[list, dict, dict]: """ Run the nodestats tracing pipeline. @@ -429,23 +659,38 @@ def run_nodestats_tracing(self) -> tuple[list, dict, dict]: Returns ------- tuple[list, dict, dict] - A list of each molecules ordered trace coordinates, the ordered_traicing stats, and the images. + A list of each molecules ordered trace coordinates, the ordered_tracing stats, and the images. 
""" - self.ordered_traces, self.images = self.compile_trace() - self.grain_tracing_stats["num_mols"] = len(self.ordered_traces) + ordered_traces, topology, cross_add, crossing_coords, fwhms = self.compile_trace( + reverse_min_conf_crossing=False + ) + self.compile_images(ordered_traces, cross_add, crossing_coords, fwhms) + self.grain_tracing_stats["num_mols"] = len(ordered_traces) + + writhe_string, node_to_writhes = self.identify_writhes() + self.grain_tracing_stats["writhe_string"] = writhe_string + for node_num, node_writhes in node_to_writhes.items(): # should self update as the dicts are linked + self.nodestats_dict[f"node_{node_num+1}"]["writhe"] = node_writhes + + topology_flip = self.compile_trace(reverse_min_conf_crossing=True)[1] ordered_trace_data = {} - for i, mol_trace in enumerate(self.ordered_traces): + grain_mol_tracing_stats = {} + for i, mol_trace in enumerate(ordered_traces): if len(mol_trace) > 3: # if > 4 coords to trace - self.mol_tracing_stats["circular"] = linear_or_circular(mol_trace) + np.save(f"trace_xyz_{i}", mol_trace) + self.mol_tracing_stats["circular"] = linear_or_circular(mol_trace[:, :2]) + self.mol_tracing_stats["topology"] = topology[i] + self.mol_tracing_stats["topology_flip"] = topology_flip[i] ordered_trace_data[f"mol_{i}"] = { - "ordered_coords": mol_trace, + "ordered_coords": mol_trace[:, :2], "heights": self.image[mol_trace[:, 0], mol_trace[:, 1]], - "distances": coord_dist(mol_trace[0]), + "distances": coord_dist(mol_trace[:, :2]), "mol_stats": self.mol_tracing_stats, } + grain_mol_tracing_stats[f"{i}"] = self.mol_tracing_stats - return ordered_trace_data, self.grain_tracing_stats, self.images + return ordered_trace_data, self.grain_tracing_stats, grain_mol_tracing_stats, self.images class OrderedTraceTopostats: @@ -481,7 +726,7 @@ def __init__( "num_mols": 1, "circular": None, } - self.mol_tracing_stats = {"circular": None} + self.mol_tracing_stats = {"circular": None, "topology": None, "topology_flip": None, 
"processing": "topostats"} self.images = { "ordered_traces": np.zeros_like(image), @@ -534,6 +779,7 @@ def run_topostats_tracing(self) -> tuple[list, dict, dict]: disordered_trace_coords = np.argwhere(self.skeleton == 1) self.mol_tracing_stats["circular"] = linear_or_circular(disordered_trace_coords) + self.mol_tracing_stats["topology"] = "0_1" if self.mol_tracing_stats["circular"] else "linear" ordered_trace = self.get_ordered_traces(disordered_trace_coords, self.mol_tracing_stats["circular"]) @@ -548,7 +794,7 @@ def run_topostats_tracing(self) -> tuple[list, dict, dict]: "mol_stats": self.mol_tracing_stats, } - return ordered_trace_data, self.grain_tracing_stats, self.images + return ordered_trace_data, self.grain_tracing_stats, {"0": self.mol_tracing_stats}, self.images def linear_or_circular(traces) -> bool: @@ -603,7 +849,7 @@ def ordered_trace_mask(ordered_coordinates: npt.NDArray, shape: tuple) -> npt.ND ordered_mask = np.zeros(shape) if isinstance(ordered_coordinates, list): for mol_coords in ordered_coordinates: - ordered_mask[mol_coords[:, 0], mol_coords[:, 1]] = np.arange(len(mol_coords)) + ordered_mask[mol_coords[:, 0], mol_coords[:, 1]] = np.arange(len(mol_coords)) + 1 return ordered_mask @@ -616,7 +862,8 @@ def ordered_tracing_image( filename: str, ordering_method: str, pad_width: int, -) -> tuple[dict, pd.DataFrame, dict]: +) -> tuple[dict, pd.DataFrame, pd.DataFrame, dict]: + # pylint: disable=too-many-locals """ Run ordered tracing for an entire image of >=1 grains. @@ -637,9 +884,9 @@ def ordered_tracing_image( Returns ------- - tuple[dict, pd.DataFrame, dict] + tuple[dict, pd.DataFrame, pd.DataFrame, dict] Results containing the ordered_trace_data (coordinates), any grain-level metrics to be added to the grains - dataframe, and the diagnostic images. + dataframe, a dataframe of molecule statistics and a dictionary of diagnostic images. 
""" ordered_trace_full_images = { "ordered_traces": np.zeros_like(image), @@ -648,52 +895,78 @@ def ordered_tracing_image( "trace_segments": np.zeros_like(image), } grainstats_additions = {} + molstats = {} all_traces_data = {} + LOGGER.info( + f"[{filename}] : Calculating Ordered Traces and statistics for " + + f"{len(disordered_tracing_direction_data)} grains..." + ) + # iterate through disordered_tracing_dict for grain_no, disordered_trace_data in disordered_tracing_direction_data.items(): - # try: - # check if want to do nodestats tracing or not - if grain_no in list(nodestats_direction_data["stats"].keys()) and ordering_method == "nodestats": - LOGGER.info(f"[{filename}] : Grain {grain_no} present in NodeStats. Tracing via Nodestats.") - nodestats_tracing = OrderedTraceNodestats( - image=nodestats_direction_data["images"][grain_no]["grain"]["grain_image"], - filename=filename, - nodestats_dict=nodestats_direction_data["stats"][grain_no], - skeleton=nodestats_direction_data["images"][grain_no]["grain"]["grain_skeleton"], - ) - if nodestats_tracing.check_node_errorless(): - ordered_traces_data, tracing_stats, images = nodestats_tracing.run_nodestats_tracing() - LOGGER.info(f"[{filename}] : Grain {grain_no} ordered via NodeStats.") + try: + # check if want to do nodestats tracing or not + if grain_no in list(nodestats_direction_data["stats"].keys()) and ordering_method == "nodestats": + LOGGER.debug(f"[{filename}] : Grain {grain_no} present in NodeStats. 
Tracing via Nodestats.") + nodestats_tracing = OrderedTraceNodestats( + image=nodestats_direction_data["images"][grain_no]["grain"]["grain_image"], + filename=filename, + nodestats_dict=nodestats_direction_data["stats"][grain_no], + skeleton=nodestats_direction_data["images"][grain_no]["grain"]["grain_skeleton"], + ) + if nodestats_tracing.check_node_errorless(): + ordered_traces_data, tracing_stats, grain_molstats, images = ( + nodestats_tracing.run_nodestats_tracing() + ) + LOGGER.debug(f"[{filename}] : Grain {grain_no} ordered via NodeStats.") + else: + LOGGER.debug(f"Nodestats dict has an error ({nodestats_direction_data['stats'][grain_no]['error']}") + # if not doing nodestats ordering, do original TS ordering else: - LOGGER.warning(f"Nodestats dict has an error ({nodestats_direction_data['stats'][grain_no]['error']}") - # if not doing nodestats ordering, do original TS ordering - else: - LOGGER.info(f"[{filename}] : {grain_no} not in NodeStats. Tracing normally.") - topostats_tracing = OrderedTraceTopostats( - image=disordered_trace_data["original_image"], - skeleton=disordered_trace_data["pruned_skeleton"], + LOGGER.debug(f"[{filename}] : {grain_no} not in NodeStats. 
Tracing normally.") + topostats_tracing = OrderedTraceTopostats( + image=disordered_trace_data["original_image"], + skeleton=disordered_trace_data["pruned_skeleton"], + ) + ordered_traces_data, tracing_stats, grain_molstats, images = topostats_tracing.run_topostats_tracing() + LOGGER.debug(f"[{filename}] : Grain {grain_no} ordered via TopoStats.") + + # compile traces + all_traces_data[grain_no] = ordered_traces_data + for mol_no, _ in ordered_traces_data.items(): + all_traces_data[grain_no][mol_no].update({"bbox": disordered_trace_data["bbox"]}) + # compile metrics + grainstats_additions[grain_no] = { + "image": filename, + "grain_number": int(grain_no.split("_")[-1]), + } + tracing_stats.pop("circular") + grainstats_additions[grain_no].update(tracing_stats) + # compile molecule metrics + for mol_no, molstat_values in grain_molstats.items(): + molstats[f"{grain_no.split('_')[-1]}_{mol_no}"] = { + "image": filename, + "grain_number": int(grain_no.split("_")[-1]), + "molecule_number": int(mol_no.split("_")[-1]), # pylint: disable=use-maxsplit-arg + } + molstats[f"{grain_no.split('_')[-1]}_{mol_no}"].update(molstat_values) + + # remap the cropped images back onto the original + for image_name, full_image in ordered_trace_full_images.items(): + crop = images[image_name] + bbox = disordered_trace_data["bbox"] + full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width] + + except Exception as e: # pylint: disable=broad-exception-caught + LOGGER.error( + f"[{filename}] : Ordered tracing for {grain_no} failed. Consider raising an issue on GitHub. 
Error: ", + exc_info=e, ) - ordered_traces_data, tracing_stats, images = topostats_tracing.run_topostats_tracing() - LOGGER.info(f"[{filename}] : Grain {grain_no} ordered via TopoStats.") - - # compile traces - all_traces_data[grain_no] = ordered_traces_data - for mol_no, _ in ordered_traces_data.items(): - all_traces_data[grain_no][mol_no].update({"bbox": disordered_trace_data["bbox"]}) - # compile metrics - grainstats_additions[grain_no] = { - "image": filename, - "grain_number": int(grain_no.split("_")[-1]), - } - tracing_stats.pop("circular") - grainstats_additions[grain_no].update(tracing_stats) - - # remap the cropped images back onto the original - for image_name, full_image in ordered_trace_full_images.items(): - crop = images[image_name] - bbox = disordered_trace_data["bbox"] - full_image[bbox[0] : bbox[2], bbox[1] : bbox[3]] += crop[pad_width:-pad_width, pad_width:-pad_width] + all_traces_data[grain_no] = {} + grainstats_additions_df = pd.DataFrame.from_dict(grainstats_additions, orient="index") + molstats_df = pd.DataFrame.from_dict(molstats, orient="index") + molstats_df.reset_index(drop=True, inplace=True) - return all_traces_data, grainstats_additions_df, ordered_trace_full_images + return all_traces_data, grainstats_additions_df, molstats_df, ordered_trace_full_images diff --git a/topostats/tracing/pruning.py b/topostats/tracing/pruning.py index a7a1bdc4e9..1051c039e6 100644 --- a/topostats/tracing/pruning.py +++ b/topostats/tracing/pruning.py @@ -11,13 +11,13 @@ from topostats.logs.logs import LOGGER_NAME from topostats.tracing.skeletonize import getSkeleton -from topostats.tracing.tracingfuncs import genTracingFuncs +from topostats.tracing.tracingfuncs import coord_dist, genTracingFuncs, order_branch from topostats.utils import convolve_skeleton LOGGER = logging.getLogger(LOGGER_NAME) -def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> npt.NDArray: +def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, 
pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray: """ Pruning skeletons using different pruning methods. @@ -29,6 +29,8 @@ def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> npt.N Original image as 2D numpy array. skeleton : npt.NDArray Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. **kwargs Pruning options passed to the respective method. @@ -39,10 +41,10 @@ def prune_skeleton(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> npt.N """ if image.shape != skeleton.shape: raise AttributeError("Error image and skeleton are not the same size.") - return _prune_method(image, skeleton, **kwargs) + return _prune_method(image, skeleton, pixel_to_nm_scaling, **kwargs) -def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> Callable: +def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> Callable: """ Determine which skeletonize method to use. @@ -52,6 +54,8 @@ def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> Callab Original image as 2D numpy array. skeleton : npt.NDArray Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. **kwargs Pruning options passed to the respective method. 
@@ -67,7 +71,7 @@ def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> Callab """ method = kwargs.pop("method") if method == "topostats": - return _prune_topostats(image, skeleton, **kwargs) + return _prune_topostats(image, skeleton, pixel_to_nm_scaling, **kwargs) # @maxgamill-sheffield I've read about a "Discrete Skeleton Evolultion" (DSE) method that might be useful # @ns-rse (2024-06-04) : https://en.wikipedia.org/wiki/Discrete_skeleton_evolution # https://link.springer.com/chapter/10.1007/978-3-540-74198-5_28 @@ -76,7 +80,7 @@ def _prune_method(image: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> Callab raise ValueError(f"Invalid pruning method provided ({method}) please use one of 'topostats'.") -def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> npt.NDArray: +def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, pixel_to_nm_scaling: float, **kwargs) -> npt.NDArray: """ Prune using the original TopoStats method. @@ -88,6 +92,8 @@ def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> npt.N Image used to find skeleton, may be original heights or binary mask. skeleton : npt.NDArray Binary mask of the skeleton. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. **kwargs Pruning options passed to the topostatsPrune class. @@ -96,7 +102,7 @@ def _prune_topostats(img: npt.NDArray, skeleton: npt.NDArray, **kwargs) -> npt.N npt.NDArray The skeleton with spurious branches removed. """ - return topostatsPrune(img, skeleton, **kwargs).prune_skeleton() + return topostatsPrune(img, skeleton, pixel_to_nm_scaling, **kwargs).prune_skeleton() # class pruneSkeleton: pylint: disable=too-few-public-methods @@ -211,6 +217,8 @@ class topostatsPrune: Original image. skeleton : npt.NDArray Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. max_length : float Maximum length of the branch to prune in nanometres (nm). 
height_threshold : float @@ -223,10 +231,12 @@ class topostatsPrune: skeleton mean - absolute threshold) or 'iqr' (below 1.5 * inter-quartile range). """ + # pylint: disable=too-many-arguments def __init__( self, img: npt.NDArray, skeleton: npt.NDArray, + pixel_to_nm_scaling: float, max_length: float = None, height_threshold: float = None, method_values: str = None, @@ -241,6 +251,8 @@ def __init__( Original image. skeleton : npt.NDArray Skeleton to be pruned. + pixel_to_nm_scaling : float + The pixel to nm scaling for pruning by length. max_length : float Maximum length of the branch to prune in nanometres (nm). height_threshold : float @@ -254,6 +266,7 @@ def __init__( """ self.img = img self.skeleton = skeleton.copy() + self.pixel_to_nm_scaling = pixel_to_nm_scaling self.max_length = max_length self.height_threshold = height_threshold self.method_values = method_values @@ -279,10 +292,10 @@ def prune_skeleton(self) -> npt.NDArray: for i in range(1, labeled_skel.max() + 1): single_skeleton = np.where(labeled_skel == i, 1, 0) if self.max_length is not None: - LOGGER.info("[pruning] : Pruning by length.") + LOGGER.debug(f": pruning.py : Pruning by length < {self.max_length}.") single_skeleton = self._prune_by_length(single_skeleton, max_length=self.max_length) if self.height_threshold is not None: - LOGGER.info("[pruning] : Pruning by height.") + LOGGER.debug(": pruning.py : Pruning by height.") single_skeleton = heightPruning( self.img, single_skeleton, @@ -301,7 +314,7 @@ def prune_skeleton(self) -> npt.NDArray: return pruned_skeleton_mask def _prune_by_length( # pylint: disable=too-many-locals # noqa: C901 - self, single_skeleton: npt.NDArray, max_length: float | int = -1 + self, single_skeleton: npt.NDArray, max_length: float ) -> npt.NDArray: """ Remove hanging branches from a skeleton by their length. 
@@ -312,71 +325,31 @@ def _prune_by_length( # pylint: disable=too-many-locals # noqa: C901 ---------- single_skeleton : npt.NDArray Binary array of the skeleton. - max_length : float | int - Maximum length of the branch to prune in nanometers (nm). Default is -1 which calculates a value that is 15% - of the total skeleton length. + max_length : float + Maximum length of the branch to prune in nanometers (nm). Returns ------- npt.NDArray Pruned skeleton as binary array. """ - pruning = True - while pruning: - single_skeleton = rm_nibs(single_skeleton) - n_branches = 0 - coordinates = np.argwhere(single_skeleton == 1).tolist() - - # The branches are typically short so if a branch is longer than - # 0.15 * total points, its assumed to be part of the real data - max_branch_length = max_length if max_length != -1 else int(len(coordinates) * 0.15) - LOGGER.info(f"[pruning] : Maximum branch length : {max_branch_length}") - # first check to find all the end coordinates in the trace - potential_branch_ends = self._find_branch_ends(coordinates) - - # Now check if its a branch - and if it is delete it - for branch_x, branch_y in potential_branch_ends: - branch_coordinates = [[branch_x, branch_y]] - branch_continues = True - temp_coordinates = coordinates[:] - temp_coordinates.pop(temp_coordinates.index([branch_x, branch_y])) - - while branch_continues: - n_neighbours, neighbours = genTracingFuncs.count_and_get_neighbours( - branch_x, branch_y, temp_coordinates - ) - - # If branch continues - if n_neighbours == 1: - branch_x, branch_y = neighbours[0] - branch_coordinates.append([branch_x, branch_y]) - temp_coordinates.pop(temp_coordinates.index([branch_x, branch_y])) - - # If the branch reaches the edge of the main trace - elif n_neighbours > 1: - branch_coordinates.pop(branch_coordinates.index([branch_x, branch_y])) - branch_continues = False - is_branch = True - - # Weird case that happens sometimes (would this be linear mols?) 
- elif n_neighbours == 0: - is_branch = True - branch_continues = False - - # why not `and branch_continues`? - if len(branch_coordinates) > max_branch_length: - branch_continues = False - is_branch = False - # - if is_branch: - n_branches += 1 - for x, y in branch_coordinates: - single_skeleton[x, y] = 0 - - if n_branches == 0: - pruning = False - - return single_skeleton + # get segments via convolution and removing junctions + conv_skeleton = convolve_skeleton(single_skeleton) + conv_skeleton[conv_skeleton == 3] = 0 + labeled_segments = morphology.label(conv_skeleton.astype(bool)) + + for segment_idx in range(1, labeled_segments.max() + 1): + # get single segment with endpoints==2 + segment = np.where(labeled_segments == segment_idx, conv_skeleton, 0) + # get segment length + ordered_coords = order_branch(np.where(segment != 0, 1, 0), [0, 0]) + segment_length = coord_dist(ordered_coords, self.pixel_to_nm_scaling)[-1] / 1e-9 + # check if endpoint + if 2 in segment and segment_length < max_length: + # prune + single_skeleton[labeled_segments == segment_idx] = 0 + + return rm_nibs(single_skeleton) @staticmethod def _find_branch_ends(coordinates: list) -> list: diff --git a/topostats/tracing/splining.py b/topostats/tracing/splining.py index 38c45df76b..9290d20970 100644 --- a/topostats/tracing/splining.py +++ b/topostats/tracing/splining.py @@ -21,6 +21,7 @@ # pylint: disable=too-many-instance-attributes class splineTrace: + # pylint: disable=too-many-instance-attributes """ Smooth the ordered trace via an average of splines. @@ -54,6 +55,7 @@ def __init__( spline_circular_smoothing: float, spline_degree: int, ) -> None: + # pylint: disable=too-many-arguments """ Initialise the splineTrace class. 
@@ -124,7 +126,7 @@ def get_splined_traces( # If the fitted trace is less than the degree plus one, then there is no # point in trying to spline it, just return the fitted trace if fitted_trace_length < self.spline_degree + 1: - LOGGER.warning( + LOGGER.debug( f"Fitted trace for grain {step_size_px} too small ({fitted_trace_length}), returning fitted trace" ) @@ -527,7 +529,9 @@ def splining_image( spline_linear_smoothing: float, spline_circular_smoothing: float, spline_degree: int, -) -> tuple[dict, pd.DataFrame]: +) -> tuple[dict, pd.DataFrame, pd.DataFrame]: + # pylint: disable=too-many-arguments + # pylint: disable=too-many-locals """ Obtain smoothed traces of pixel-wise ordered traces for molecules in an image. @@ -557,13 +561,19 @@ def splining_image( Returns ------- - tuple[dict, pd.DataFrame] - A spline data dictionary for all molecules, and a grainstats dataframe additions dataframe. + tuple[dict, pd.DataFrame, pd.DataFrame] + A spline data dictionary for all molecules, and a grainstats dataframe additions dataframe and molecule + statistics dataframe. 
""" grainstats_additions = {} molstats = {} all_splines_data = {} + mol_count = 0 + for mol_trace_data in ordered_tracing_direction_data.values(): + mol_count += len(mol_trace_data) + LOGGER.info(f"[{filename}] : Calculating Splining statistics for {mol_count} molecules...") + # iterate through disordered_tracing_dict for grain_no, ordered_grain_data in ordered_tracing_direction_data.items(): grain_trace_stats = {"total_contour_length": 0, "average_end_to_end_distance": 0} @@ -571,7 +581,7 @@ def splining_image( mol_no = None for mol_no, mol_trace_data in ordered_grain_data.items(): try: - LOGGER.info(f"[{filename}] : Splining {grain_no} - {mol_no}") + LOGGER.debug(f"[{filename}] : Splining {grain_no} - {mol_no}") # check if want to do nodestats tracing or not if method == "rolling_window": splined_data, tracing_stats = windowTrace( @@ -604,13 +614,17 @@ def splining_image( } molstats[grain_no.split("_")[-1] + "_" + mol_no.split("_")[-1]] = { "image": filename, - "grain_number": grain_no.split("_")[-1], + "grain_number": int(grain_no.split("_")[-1]), + "molecule_number": int(mol_no.split("_")[-1]), } molstats[grain_no.split("_")[-1] + "_" + mol_no.split("_")[-1]].update(tracing_stats) - LOGGER.info(f"[{filename}] : Finished splining {grain_no} - {mol_no}") + LOGGER.debug(f"[{filename}] : Finished splining {grain_no} - {mol_no}") except Exception as e: # pylint: disable=broad-exception-caught - LOGGER.error(f"[{filename}] : Ordered tracing for {grain_no} failed with - {e}") + LOGGER.error( + f"[{filename}] : Splining for {grain_no} failed. Consider raising an issue on GitHub. 
Error: ", + exc_info=e, + ) all_splines_data[grain_no] = {} if mol_no is None: diff --git a/topostats/tracing/tracingfuncs.py b/topostats/tracing/tracingfuncs.py index bdf7c20211..89aca0e816 100644 --- a/topostats/tracing/tracingfuncs.py +++ b/topostats/tracing/tracingfuncs.py @@ -506,7 +506,7 @@ def local_area_sum(binary_map: npt.NDArray, point: list | tuple | npt.NDArray) - @staticmethod -def coord_dist(coords: npt.NDArray, px_2_nm: float = 1) -> npt.NDArray: +def coord_dist(coords: npt.NDArray, pixel_to_nm_scaling: float = 1) -> npt.NDArray: """ Accumulate a real distance traversing from pixel to pixel from a list of coordinates. @@ -514,7 +514,7 @@ def coord_dist(coords: npt.NDArray, px_2_nm: float = 1) -> npt.NDArray: ---------- coords : npt.NDArray A Nx2 integer array corresponding to the ordered coordinates of a binary trace. - px_2_nm : float + pixel_to_nm_scaling : float The pixel to nanometer scaling factor. Returns @@ -530,4 +530,4 @@ def coord_dist(coords: npt.NDArray, px_2_nm: float = 1) -> npt.NDArray: else: dist += 1 dist_list.append(dist) - return np.asarray(dist_list) * px_2_nm + return np.asarray(dist_list) * pixel_to_nm_scaling diff --git a/topostats/validation.py b/topostats/validation.py index d7449358cf..fb1a2e596c 100644 --- a/topostats/validation.py +++ b/topostats/validation.py @@ -243,7 +243,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "topostats", error="Invalid value in config for 'disordered_tracing.pruning_method', valid values are 'topostats'", ), - "max_length": Or(int, float, None), + "max_length": lambda n: n >= 0, "method_values": Or("min", "median", "mid"), "method_outlier": Or("abs", "mean_abs", "iqr"), "height_threshold": Or(int, float, None), @@ -729,6 +729,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "labelled_regions_01": { "filename": str, 
@@ -747,6 +748,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "tidied_border": { "filename": str, @@ -764,6 +766,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "removed_noise": { "filename": str, @@ -775,6 +778,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "Invalid value in config 'removed_noise.image_type', valid values " "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -793,6 +797,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -827,6 +832,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -845,6 +851,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "are 'binary' or 'non-binary'" ), ), + "mask_cmap": str, "core_set": bool, "savefig_dpi": Or( lambda n: n > 0, @@ -885,6 +892,7 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "figure", error="Invalid value in config for 'dpi', valid values are 'figure' or > 0.", ), + "mask_cmap": str, }, "grain_image": { "image_type": Or( @@ -983,6 +991,20 @@ def validate_config(config: dict, schema: Schema, config_type: str) -> None: "core_set": bool, "savefig_dpi": int, }, + "branch_indexes": { + "filename": str, + "title": str, + "image_type": Or( + "binary", + "non-binary", + error=( + "Invalid value in config 'branch_indexes.image_type', valid values " "are 'binary' or 'non-binary'" + 
), + ), + "mask_cmap": str, + "core_set": bool, + "savefig_dpi": int, + }, "branch_types": { "filename": str, "title": str,