Skip to content

Commit

Permalink
Test OlderMonocotPipeline and get_grav_index
Browse files Browse the repository at this point in the history
  • Loading branch information
eberrigan committed Aug 31, 2023
1 parent 797561b commit aa59e2b
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 3 deletions.
2 changes: 1 addition & 1 deletion sleap_roots/trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -1249,7 +1249,7 @@ def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, A
- "main_pts": Array of main root points.
"""
# Get the root instances.
main = plant[frame_idx]
main = plant[frame_idx][0]
gt_instances_lr = main.user_instances + main.unused_predictions

# Convert the instances to numpy arrays.
Expand Down
26 changes: 26 additions & 0 deletions tests/test_lengths.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,32 @@ def test_get_grav_index(canola_h5):
np.testing.assert_almost_equal(grav_index, 0.08898137324716636)


def test_grav_index_float():
assert get_grav_index(10.0, 5.0) == 0.5


def test_grav_index_float_invalid():
assert np.isnan(get_grav_index(np.nan, 5.0))


def test_grav_index_array():
lengths = np.array([10, 20, 30, 0, np.nan])
base_tip_dists = np.array([5, 15, 25, 0, np.nan])
expected = np.array([0.5, 0.25, 0.16666667, np.nan, np.nan])
np.testing.assert_allclose(
get_grav_index(lengths, base_tip_dists), expected, rtol=1e-6
)


def test_grav_index_mixed_invalid():
lengths = np.array([10, np.nan, 0])
base_tip_dists = np.array([5, 5, 5])
expected = np.array([0.5, np.nan, np.nan])
np.testing.assert_allclose(
get_grav_index(lengths, base_tip_dists), expected, rtol=1e-6
)


def test_get_root_lengths(canola_h5):
series = Series.load(
canola_h5, primary_name="primary_multi_day", lateral_name="lateral_3_nodes"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_trait_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def test_dicot_pipeline(canola_h5, soy_h5):


def test_OlderMonocot_pipeline(rice_main_10do_h5):
rice = Series.load(rice_main_10do_h5, ["main_10do_6nodes"])[0]
rice = Series.load(rice_main_10do_h5, ["main_10do_6nodes"])

pipeline = OlderMonocotPipeline()
rice_10dag_traits = pipeline.compute_plant_traits(rice)

assert rice_10dag_traits.shape == (72, 115)
assert rice_10dag_traits.shape == (72, 98)

0 comments on commit aa59e2b

Please sign in to comment.