Add multiple dicot pipeline #71
Conversation
Walkthrough
This update encompasses significant enhancements across the project, focusing on improved data management with Git LFS and expanded root analysis functionalities.
Review Status
Actionable comments generated: 1
Configuration used: CodeRabbit UI
Files ignored due to path filters (1)
- tests/data/multiple_arabidopsis_11do/merged_proofread_samples_03122024.csv is excluded by `!**/*.csv`
Files selected for processing (24)
- .gitattributes (1 hunks)
- sleap_roots/init.py (1 hunks)
- sleap_roots/lengths.py (2 hunks)
- sleap_roots/points.py (2 hunks)
- sleap_roots/series.py (9 hunks)
- sleap_roots/trait_pipelines.py (7 hunks)
- tests/data/multiple_arabidopsis_11do/6039_1.h5 (1 hunks)
- tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/7327_2.h5 (1 hunks)
- tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/9535_1.h5 (1 hunks)
- tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/997_1.h5 (1 hunks)
- tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp (1 hunks)
- tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp (1 hunks)
- tests/fixtures/data.py (1 hunks)
- tests/test_lengths.py (2 hunks)
- tests/test_points.py (2 hunks)
- tests/test_trait_pipelines.py (2 hunks)
Files skipped from review due to trivial changes (15)
- .gitattributes
- tests/data/multiple_arabidopsis_11do/6039_1.h5
- tests/data/multiple_arabidopsis_11do/6039_1.lateral.predictions.slp
- tests/data/multiple_arabidopsis_11do/6039_1.primary.predictions.slp
- tests/data/multiple_arabidopsis_11do/7327_2.h5
- tests/data/multiple_arabidopsis_11do/7327_2.lateral.predictions.slp
- tests/data/multiple_arabidopsis_11do/7327_2.primary.predictions.slp
- tests/data/multiple_arabidopsis_11do/9535_1.h5
- tests/data/multiple_arabidopsis_11do/9535_1.lateral.predictions.slp
- tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_lateral.predictions.slp
- tests/data/multiple_arabidopsis_11do/9535_1.not_proofread_primary.predictions.slp
- tests/data/multiple_arabidopsis_11do/9535_1.primary.predictions.slp
- tests/data/multiple_arabidopsis_11do/997_1.h5
- tests/data/multiple_arabidopsis_11do/997_1.lateral.predictions.slp
- tests/data/multiple_arabidopsis_11do/997_1.primary.predictions.slp
Additional comments: 29
sleap_roots/__init__.py (2)
- 20-20: The addition of `MultipleDicotPipeline` to the exported entities aligns with the PR's objectives to enhance the package's capabilities. This change makes the new functionality accessible to users.
- 26-26: Updating the package version from "0.0.6" to "0.0.7" correctly reflects the addition of new features. This version increment follows semantic versioning principles, indicating a minor update.
tests/fixtures/data.py (1)
tests/fixtures/data.py (1)
- 94-115: The addition of fixtures for multiple Arabidopsis 11-day-old data, including paths for image stacks, primary root predictions, and lateral root predictions, is essential for testing the new functionality related to handling multiple plants. These fixtures provide the necessary paths to test data, supporting comprehensive testing of the new features.
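For illustration, a minimal sketch of what one of these path fixtures might look like (the fixture name and docstring are illustrative, not the exact fixtures added in tests/fixtures/data.py):

import pytest

@pytest.fixture
def multiple_arabidopsis_11do_folder():
    """Hypothetical sketch: path to the multiple-plant Arabidopsis 11-day-old test data."""
    return "tests/data/multiple_arabidopsis_11do"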
sleap_roots/lengths.py (2)
- 13-40: The modifications to `get_max_length_pts` enhance the function's robustness by adding input validation and handling different input shapes. These changes ensure that the function can process both single instances and multiple instances of root landmarks effectively, improving the utility of the function for various use cases.
- 137-158: The addition of `get_min_distance_line_to_line` expands the package's geometric analysis capabilities by enabling the calculation of the minimum distance between two `LineString` objects. This function is valuable for analyzing spatial relationships between different root structures, and the implementation includes necessary input validation for reliability.
tests/test_trait_pipelines.py (3)
- 1-1: The addition of `import numpy as np` is necessary for numerical computations in the new test function `test_multiple_dicot_pipeline`. This change supports the testing of the new functionality.
- 6-6: Adding `MultipleDicotPipeline` to the list of imports is essential for testing its functionality within the new test function `test_multiple_dicot_pipeline`. This change enables comprehensive testing of the new feature.
- 140-174: The new test function `test_multiple_dicot_pipeline` is crucial for validating the functionality of the `MultipleDicotPipeline`. It tests the pipeline's ability to compute traits for multiple dicot plants, ensuring the feature works as expected. This addition aligns with the PR's objectives to enhance the software's capabilities in analyzing root traits of multiple plants.
tests/test_lengths.py (1)
- 150-171: The new test function `test_min_distance_line_to_line` provides comprehensive testing for the `get_min_distance_line_to_line` function, covering various scenarios including non-intersecting lines, intersecting lines, parallel lines, and invalid input types. This thorough testing ensures the function's reliability and correctness.
sleap_roots/series.py (3)
- 48-48: The addition of the `csv_path` attribute to the `Series` class supports the new functionality of loading expected plant counts from a CSV file. This attribute is essential for specifying the path to the CSV file containing the expected plant counts.
- 133-147: The implementation of the `expected_count` property enhances the `Series` class by enabling the retrieval of expected plant counts from a CSV file. This addition is crucial for analyses that depend on the expected number of plants. The error handling for missing CSV files and unmatched series names ensures robustness.
- 273-281: 📝 NOTE: This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [251-304]. The error handling in the property getters for primary, lateral, and crown root points improves the usability of the `Series` class by providing clear feedback when labels are not available. This ensures that users are informed of missing data, enhancing the class's reliability.
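A minimal sketch of how such a CSV lookup could work, assuming hypothetical column names ("plant_qr_code", "number_of_plants_cylinder") that may differ from the actual implementation:

import numpy as np
import pandas as pd

def lookup_expected_count(csv_path: str, series_name: str) -> float:
    """Sketch: look up the expected plant count for a series from a CSV."""
    df = pd.read_csv(csv_path)
    match = df[df["plant_qr_code"] == series_name]  # hypothetical column name
    if match.empty:
        return np.nan  # unmatched series name
    return float(match["number_of_plants_cylinder"].iloc[0])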
sleap_roots/points.py (4)
- 4-7: Import statements for `matplotlib`, `Line2D`, `LineString`, and `nearest_points` are correctly added to support the new functionalities introduced in this file.
- 294-317: The `filter_roots_with_nans` function correctly filters out roots containing NaN values. It includes input validation and handles the case where all roots contain NaN values by returning an empty array. This approach ensures robustness in data processing.
- 320-361: The `filter_plants_with_unexpected_ct` function correctly filters primary and lateral roots based on an expected count. It includes comprehensive input validation and correctly handles NaN expected counts. The approach of adjusting primary and lateral roots to empty arrays when the count does not match is clear and effective.
- 536-593: The `plot_root_associations` function provides a visual representation of the associations between primary and lateral roots, including the minimum distance lines. It correctly uses matplotlib for plotting and customizes the legend and color map. However, the use of a red dashed line (`"r--"`) for minimum distance might conflict with the comment about ensuring the color map does not include red. This is not a critical issue, but something to be aware of in terms of visual clarity.
tests/test_points.py (5)
- 3-16: The import statements are correctly updated to include `LineString` and additional functions from `sleap_roots.points`. This ensures that the new functionalities introduced in `sleap_roots/points.py` are properly tested.
- 364-382: The `test_associate_basic` function correctly tests the basic association between one primary and one lateral root. It includes comprehensive assertions to verify the structure and content of the association result. This test effectively validates the expected behavior of the `associate_lateral_to_primary` function.
- 385-393: The `test_associate_no_primary` function correctly tests the scenario where there are no primary roots. It validates that an empty dictionary is returned, which is the expected behavior. This test ensures that the function handles edge cases gracefully.
- 621-626: The `test_filter_roots_with_nans_no_nans` function effectively tests the `filter_roots_with_nans` function with an input array that contains no NaN values. It correctly asserts that the original array should be returned, validating the function's behavior in a scenario without NaN values.
- 687-696: The `test_filter_plants_with_unexpected_ct_valid_input_matching_count` function correctly tests the scenario where the number of primary roots matches the expected count. It validates that the original primary and lateral points arrays are returned, ensuring the function behaves as expected when the count matches.
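To make the tested behavior concrete, a small sketch of calling filter_roots_with_nans on a mixed array (coordinates are made up; the drop-NaN behavior is as described in the review above):

import numpy as np
from sleap_roots.points import filter_roots_with_nans

roots = np.array(
    [
        [[0.0, 0.0], [1.0, 1.0]],     # clean root: kept
        [[np.nan, 0.0], [1.0, 1.0]],  # contains a NaN: dropped
    ]
)
filtered = filter_roots_with_nans(roots)
assert filtered.shape[0] == 1  # only the clean root remains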
sleap_roots/trait_pipelines.py (8)
- 3-3: The addition of `import json` is appropriate for the new functionality related to JSON serialization of NumPy arrays.
- 118-125: The `NumpyArrayEncoder` class is a well-implemented custom JSON encoder for handling NumPy arrays. It correctly checks if the object is an instance of `np.ndarray` and converts it to a list, which is JSON serializable.
- 127-127: The `TraitDef` class is a significant addition that provides a structured way to define traits for analysis. It includes comprehensive documentation and a clear structure, enhancing maintainability and readability.
- 261-268: The `csv_traits_multiple_plants` property method is a thoughtful addition for handling CSV traits specific to scenarios involving multiple plants. It ensures that only traits marked for inclusion in CSVs are processed, which is crucial for performance and correctness.
- 376-481: The `compute_multiple_dicots_traits` function introduces complex logic for computing traits across multiple dicots. It's well-structured and includes error handling for file operations. However, consider enhancing the error messages in the exception handling blocks to include more context about the failure, which can aid in debugging.
- 525-568: The `compute_batch_multiple_dicots_traits` function efficiently processes a batch of series with multiple dicots. It demonstrates good use of existing functionality and maintains consistency in handling CSV output. The logging of the processing status for each series is a helpful addition for users.
- 1977-2046: The `MultipleDicotPipeline` class is a crucial addition for handling multiple dicot plants. It defines traits specific to this scenario and leverages existing functions for filtering and associating roots. The structure and documentation of this class contribute to its readability and maintainability.
- 2051-2071: The `get_initial_frame_traits` method in the `MultipleDicotPipeline` class correctly gathers initial traits for a plant frame, including handling the expected plant count. This method is essential for initializing the trait computation process and is implemented correctly.
Review Status
Actionable comments generated: 0
Configuration used: CodeRabbit UI
Files selected for processing (1)
- sleap_roots/lengths.py (2 hunks)
Files skipped from review as they are similar to previous changes (1)
- sleap_roots/lengths.py
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main      #71      +/-   ##
==========================================
- Coverage   80.65%   74.90%    -5.76%
==========================================
  Files          13       13
  Lines        1003     1307     +304
==========================================
+ Hits          809      979     +170
- Misses        194      328     +134

☔ View full report in Codecov by Sentry.
Actionable comments posted: 1
Review Status
Configuration used: CodeRabbit UI
Files selected for processing (8)
- .github/workflows/ci.yml (1 hunks)
- sleap_roots/bases.py (1 hunks)
- sleap_roots/points.py (2 hunks)
- sleap_roots/series.py (9 hunks)
- sleap_roots/trait_pipelines.py (7 hunks)
- tests/fixtures/data.py (1 hunks)
- tests/test_series.py (3 hunks)
- tests/test_trait_pipelines.py (2 hunks)
Files skipped from review as they are similar to previous changes (3)
- sleap_roots/series.py
- tests/fixtures/data.py
- tests/test_trait_pipelines.py
Additional comments not posted (17)
.github/workflows/ci.yml (1)
69-70: Enabling Git LFS during the repository checkout step is a crucial addition for handling large files in CI workflows. This ensures that all necessary files are properly fetched, facilitating more comprehensive testing and integration checks.
tests/test_series.py (3)
9-13: Adding a `series_instance` fixture simplifies the creation of a `Series` instance with dummy data for testing. This enhances test modularity and reusability.
50-57: The `csv_path` fixture for creating a dummy CSV file is a valuable addition for testing series properties. Ensure that the dummy CSV content aligns with the expected format and fields for `Series` instances.
79-86: Utilizing the new fixtures in tests, as seen in `test_series_name` and `test_expected_count`, improves test clarity and maintainability. These modifications ensure that the tests are more focused and easier to understand.
sleap_roots/bases.py (1)
279-282: Initializing `default_dists`, `default_left_bases`, and `default_right_bases` with NaN values using `np.full` ensures consistent handling of missing data. This modification improves the clarity and robustness of the `get_root_widths` function by clearly indicating cases with no valid data.
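The NaN-initialization pattern is simply the following (array sizes here are illustrative, not the exact shapes used in get_root_widths):

import numpy as np

n = 4  # hypothetical number of width measurements
default_dists = np.full(n, np.nan)             # one NaN distance per measurement
default_left_bases = np.full((n, 2), np.nan)   # NaN (x, y) for each left base
default_right_bases = np.full((n, 2), np.nan)  # NaN (x, y) for each right base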
sleap_roots/points.py (4)
4-7: The addition of imports for `matplotlib` and `shapely` supports the new functionality for plotting root associations and performing spatial analysis. These imports are essential for the added capabilities.
294-317: The `filter_roots_with_nans` function is a useful addition for preprocessing root data by removing instances with NaN values. This function enhances data cleanliness and reliability for subsequent analyses.
320-361: The `filter_plants_with_unexpected_ct` function provides a mechanism to filter primary and lateral roots based on an expected count, which is crucial for ensuring data consistency. This function adds robustness to the preprocessing steps.
537-596: The `plot_root_associations` function introduces visualization capabilities for root associations, enhancing the interpretability of the analysis. The use of `matplotlib` for plotting and the careful consideration of plot aesthetics, such as the color map and axis inversion, are commendable.
sleap_roots/trait_pipelines.py (8)
3-3: The addition of `import json` is appropriate for JSON serialization tasks introduced in this update.
118-135: The `NumpyArrayEncoder` class is well-implemented for custom JSON serialization of NumPy arrays and `np.int64` types. It correctly falls back to the base class method for other types.
115-140: 📝 NOTE: This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [137-211]. The `TraitDef` class is well-defined, providing a clear structure for trait computation. It includes comprehensive documentation and a flexible design to accommodate various trait computations. Ensure that all functions referenced in `fn` are implemented and tested.
383-625: 📝 NOTE: This review was outside the diff hunks and was mapped to the diff hunk with the greatest overlap. Original lines [213-699]. The `Pipeline` class and its subclasses (`DicotPipeline`, `MultipleDicotPipeline`) are central to the trait computation logic. They are well-structured and include methods for defining traits, computing traits for frames, and handling multiple plants. However, there are a few areas that could be improved for clarity and efficiency:
- In the `compute_multiple_dicots_traits` method, consider handling exceptions more gracefully when writing JSON or CSV files. Instead of printing errors directly, it might be beneficial to log these errors or raise exceptions to be handled by the caller.
- The method `get_initial_frame_traits` in the `MultipleDicotPipeline` class uses `expected_plant_ct`, which is not defined within the provided context. Ensure that this attribute is correctly initialized and used within the `Series` class.
- The use of print statements for logging (e.g., `print(f"Processing series '{series.series_name}'")`) is not ideal for production code. Consider using a logging framework that allows for different logging levels and better control over the output.
386-496: The `compute_multiple_dicots_traits` method is comprehensive and covers various aspects of trait computation for multiple dicots. However, consider adding error handling for file operations and replacing print statements with logging for better control and flexibility in output management.
498-605: The `compute_multiple_dicots_traits_for_groups` method aggregates traits over groups of samples efficiently. Similar to previous comments, consider enhancing error handling and replacing print statements with a more robust logging approach.
655-698: The `compute_batch_multiple_dicots_traits` method effectively summarizes traits for a batch of series with multiple dicots. Ensure that the CSV writing process is robust against potential errors, and consider using logging instead of print statements.
2106-2198: The `MultipleDicotPipeline` class introduces a specialized pipeline for computing traits for multiple dicot plants. The design and implementation are consistent with the rest of the module. Ensure that all referenced methods (e.g., `filter_roots_with_nans`, `filter_plants_with_unexpected_ct`) are implemented and tested for correctness.
def associate_lateral_to_primary(
    primary_pts: np.ndarray, lateral_pts: np.ndarray
) -> dict:
    """Associates each lateral root with the closest primary root.

    Args:
        primary_pts: A numpy array of primary root points with shape
            (instances, nodes, 2), where 'instances' is the number of primary roots,
            'nodes' is the number of points in each root, and '2' corresponds to the
            x and y coordinates. Points cannot have NaN values.
        lateral_pts: A numpy array of lateral root points with a shape similar
            to primary_pts, representing the lateral roots. Points cannot have NaN values.

    Returns:
        dict: A dictionary where each key is an index of a primary root (from the
            primary_pts array) and each value is a dictionary containing
            'primary_points' as the points of the primary root (1, nodes, 2) and
            'lateral_points' as an array of lateral root points that are closest to
            that primary root. The shape of 'lateral_points' is (instances, nodes, 2),
            where instances is the number of lateral roots associated with the
            primary root.
    """
    # Basic input validation
    if not isinstance(primary_pts, np.ndarray) or not isinstance(
        lateral_pts, np.ndarray
    ):
        raise ValueError("Both primary_pts and lateral_pts must be numpy arrays.")
    if len(primary_pts.shape) != 3 or len(lateral_pts.shape) != 3:
        raise ValueError("Input arrays must have a shape of (instances, nodes, 2).")
    if primary_pts.shape[2] != 2 or lateral_pts.shape[2] != 2:
        raise ValueError(
            "The last dimension of input arrays must be 2, representing x and y coordinates."
        )

    plant_associations = {}

    # Initialize plant associations dictionary
    for i, primary_root in enumerate(primary_pts):
        if not is_line_valid(primary_root):
            continue  # Skip primary roots containing NaN values
        plant_associations[i] = {
            "primary_points": primary_root,
            "lateral_points": [],
        }

    # Associate each lateral root with the closest primary root
    for lateral_root in lateral_pts:
        if not is_line_valid(lateral_root):
            continue  # Skip lateral roots containing NaN values

        lateral_line = LineString(lateral_root)
        min_distance = float("inf")
        closest_primary_index = None

        for primary_index, primary_data in plant_associations.items():
            primary_root = primary_data["primary_points"]
            try:
                primary_line = LineString(primary_root)
                distance = primary_line.distance(lateral_line)
            except Exception as e:
                print(f"Error computing distance: {e}")
                continue

            if distance < min_distance:
                min_distance = distance
                closest_primary_index = primary_index

        if closest_primary_index is not None:
            plant_associations[closest_primary_index]["lateral_points"].append(
                lateral_root
            )

    # Convert lateral points lists into arrays
    for primary_index, data in plant_associations.items():
        lateral_points_list = data["lateral_points"]
        if lateral_points_list:  # Check if there are any lateral points to convert
            lateral_points_array = np.array(lateral_points_list)
            plant_associations[primary_index]["lateral_points"] = lateral_points_array
        else:
            # Create an array of NaNs if there are no lateral points
            shape = (1, lateral_pts.shape[1], 2)  # Shape of lateral points array
            plant_associations[primary_index]["lateral_points"] = np.full(shape, np.nan)

    return plant_associations
The `associate_lateral_to_primary` function effectively associates lateral roots with the closest primary root. It uses `LineString` for representing roots and `nearest_points` for finding the closest points, which is a robust approach for spatial analysis. However, consider catching specific exceptions instead of a general `Exception` to improve error handling specificity.

- except Exception as e:
+ except ValueError as e:  # Adjust to the specific exception type if different
Actionable comments posted: 3
Review Status
Configuration used: CodeRabbit UI
Files selected for processing (5)
- MultiDicotPipeline.ipynb (1 hunks)
- sleap_roots/bases.py (2 hunks)
- sleap_roots/trait_pipelines.py (7 hunks)
- tests/test_bases.py (2 hunks)
- tests/test_trait_pipelines.py (2 hunks)
Files skipped from review as they are similar to previous changes (2)
- sleap_roots/bases.py
- tests/test_trait_pipelines.py
Additional Context Used
GitHub Check Runs (1)
codecov/patch success (30)
sleap_roots/trait_pipelines.py: [warning] 130-133: sleap_roots/trait_pipelines.py#L130-L133
Added lines #L130 - L133 were not covered by tests
sleap_roots/trait_pipelines.py: [warning] 135-135: sleap_roots/trait_pipelines.py#L135
Added line #L135 was not covered by tests
sleap_roots/trait_pipelines.py: [warning] 416-416: sleap_roots/trait_pipelines.py#L416
Added line #L416 was not covered by tests
sleap_roots/trait_pipelines.py: [warning] 418-418: sleap_roots/trait_pipelines.py#L418
Added line #L418 was not covered by tests
sleap_roots/trait_pipelines.py: [warning] 467-470: sleap_roots/trait_pipelines.py#L467-L470
Added lines #L467 - L470 were not covered by tests
Additional comments not posted (11)
tests/test_bases.py (4)
- 379-384: The use of `np.nan` in the expected results for `test_get_root_widths` with `return_inds=True` is correct for handling cases where the primary or lateral roots are not present or cannot be calculated due to missing data. However, it's important to ensure that the handling of `np.nan` values is consistent across all related functions and that any downstream processing of these results properly accounts for `np.nan` values to avoid unexpected behavior.
- 429-439: The test case for `get_root_widths` with minimum length inputs correctly asserts that the result should be `np.nan`. This is a good practice for handling edge cases where the input data does not meet the minimum requirements for meaningful computation. It's crucial to document these edge cases and the expected behavior in both the test and the function being tested to ensure clarity for future maintenance.
- 390-395: In the parameterized test for `get_root_widths`, handling of `np.nan` values is consistent with the expected behavior when primary or lateral roots are missing or invalid. This approach ensures robustness in the face of incomplete or corrupted data. It would be beneficial to include a brief comment explaining the rationale behind each test case scenario for future reference and clarity.
- 429-446: The handling of invalid cases in `test_get_root_widths_invalid_cases` through explicit `ValueError` exceptions is a good practice for early detection of incorrect usage patterns. It's important to ensure that these error messages are descriptive enough to aid in debugging and that similar validation is performed in the main function to prevent misuse.
MultiDicotPipeline.ipynb (4)
- 26-29: The paths for CSV and folder are hardcoded, which might limit the notebook's usability in different environments or datasets. Consider using a configuration file or environment variables to make these paths configurable, as in the sketch after this list. This approach enhances the notebook's flexibility and reusability across different datasets or project structures.
- 117-118: Initializing the `MultipleDicotPipeline` without any configuration parameters is straightforward for this example. However, if the pipeline supports customization (e.g., different processing strategies, parameters), it would be beneficial to demonstrate how to configure these options or link to documentation for further details. This information aids users in understanding how to adapt the pipeline to their specific needs.
- 136-137: When computing traits for the first sample and writing the results to JSON and CSV, it's good practice to provide feedback to the user about the operation's success and the location of the output files. This feedback can be particularly helpful in a notebook environment where users might not immediately notice file generation. Consider adding print statements or logging to inform the user.
- 668-669: The function `compute_batch_multiple_dicots_traits_for_groups` is called with `write_json=True` and `write_csv=True`, which is consistent with the intent to save the computed traits. However, it's unclear where these files are saved. Providing explicit paths or confirming the output directory would improve usability and clarity for users trying to locate the generated files.
sleap_roots/trait_pipelines.py (3)
- 3-6: The added imports `json` and `from typing import Union` are correctly used in the context of the new functionality introduced in this file.
- 137-137: The `TraitDef` class is well-defined and documented, providing a clear structure for trait computation. Good use of `attrs` for concise attribute definitions.
- 137-137: The `Pipeline` class is correctly implemented with smart use of `networkx` for determining the computation order of traits. Consider adding more detailed documentation for methods involving complex logic, such as `get_computation_order`, to enhance maintainability. A small illustration of the idea follows this list.
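To illustrate the networkx-based ordering (a standalone sketch, not the Pipeline internals; trait names are taken from the MultipleDicotPipeline definitions later in this PR):

import networkx as nx

g = nx.DiGraph()
# Edges point from an input trait to the trait computed from it.
g.add_edge("primary_pts", "primary_pts_no_nans")
g.add_edge("lateral_pts", "lateral_pts_no_nans")
g.add_edge("primary_pts_no_nans", "filtered_pts_expected_plant_ct")
g.add_edge("lateral_pts_no_nans", "filtered_pts_expected_plant_ct")
g.add_edge("expected_plant_ct", "filtered_pts_expected_plant_ct")

# A topological sort yields a valid computation order.
print(list(nx.topological_sort(g)))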
        except Exception as e:
            raise IOError(f"Failed to write computed traits to CSV: {e}")

        return all_group_summaries_df
The `DicotPipeline` class is well-implemented, providing specific trait computations for dicot plants. Given the complexity of these computations, ensure that unit tests cover this new functionality comprehensively.

+ # TODO: Add comprehensive unit tests for DicotPipeline trait computations
Actionable comments posted: 3
Review Status
Configuration used: CodeRabbit UI
Files selected for processing (1)
- sleap_roots/trait_pipelines.py (7 hunks)
Additional Context Used
GitHub Check Runs (1)
codecov/patch success (30)
sleap_roots/trait_pipelines.py: [warning] 130-133: sleap_roots/trait_pipelines.py#L130-L133
Added lines #L130 - L133 were not covered by tests
sleap_roots/trait_pipelines.py: [warning] 135-135: sleap_roots/trait_pipelines.py#L135
Added line #L135 was not covered by tests
sleap_roots/trait_pipelines.py: [warning] 416-416: sleap_roots/trait_pipelines.py#L416
Added line #L416 was not covered by tests
sleap_roots/trait_pipelines.py: [warning] 418-418: sleap_roots/trait_pipelines.py#L418
Added line #L418 was not covered by tests
sleap_roots/trait_pipelines.py: [warning] 467-470: sleap_roots/trait_pipelines.py#L467-L470
Added lines #L467 - L470 were not covered by tests
class NumpyArrayEncoder(json.JSONEncoder):
    """Custom encoder for NumPy array types."""

    def default(self, obj):
        """Serialize NumPy arrays to lists.

        Args:
            obj: The object to serialize.

        Returns:
            A list representation of the NumPy array.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        elif isinstance(obj, np.int64):
            return int(obj)
        # Let the base class default method raise the TypeError
        return json.JSONEncoder.default(self, obj)
The `NumpyArrayEncoder` class is correctly implemented to handle NumPy types during JSON serialization. However, some lines in this class were not covered by tests. It's important to ensure comprehensive test coverage for this functionality.

+ # TODO: Add tests to cover the NumpyArrayEncoder functionality
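A sketch of a test that would cover the encoder (assuming NumpyArrayEncoder is importable from sleap_roots.trait_pipelines, where it is defined):

import json
import numpy as np
from sleap_roots.trait_pipelines import NumpyArrayEncoder

def test_numpy_array_encoder():
    payload = {"lengths": np.array([1.5, 2.5]), "count": np.int64(3)}
    encoded = json.dumps(payload, cls=NumpyArrayEncoder)
    # Arrays round-trip as lists; np.int64 round-trips as a plain int.
    assert json.loads(encoded) == {"lengths": [1.5, 2.5], "count": 3}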
    def compute_multiple_dicots_traits(
        self,
        series: Series,
        write_json: bool = False,
        json_suffix: str = ".all_frames_traits.json",
        write_csv: bool = False,
        csv_suffix: str = ".all_frames_summary.csv",
    ):
        """Computes plant traits for pipelines with multiple plants over all frames in a series.

        Args:
            series: The Series object containing the primary and lateral root points.
            write_json: Whether to write the aggregated traits to a JSON file. Default is False.
            json_suffix: The suffix to append to the JSON file name. Default is ".all_frames_traits.json".
            write_csv: Whether to write the summary statistics to a CSV file. Default is False.
            csv_suffix: The suffix to append to the CSV file name. Default is ".all_frames_summary.csv".

        Returns:
            A dictionary containing the series name, group, aggregated traits, and summary statistics.
        """
        # Initialize the return structure with the series name and group
        result = {
            "series": str(series.series_name),
            "group": str(series.group),
            "traits": {},
            "summary_stats": {},
        }

        # Check if the series has frames to process
        if len(series) == 0:
            print(f"Series '{series.series_name}' contains no frames to process.")
            # Return early with the initialized structure
            return result

        # Initialize a separate dictionary to hold the aggregated traits across all frames
        aggregated_traits = {}

        # Iterate over frames in series
        for frame in range(len(series)):
            # Get initial points and number of plants per frame
            initial_frame_traits = self.get_initial_frame_traits(series, frame)
            # Compute initial associations and perform filter operations
            frame_traits = self.compute_frame_traits(initial_frame_traits)

            # Instantiate DicotPipeline
            dicot_pipeline = DicotPipeline()

            # Extract the plant associations for this frame
            associations = frame_traits["plant_associations_dict"]

            for primary_idx, assoc in associations.items():
                primary_pts = assoc["primary_points"]
                lateral_pts = assoc["lateral_points"]
                # Get the initial frame traits for this plant using the primary and lateral points
                initial_frame_traits = {
                    "primary_pts": primary_pts,
                    "lateral_pts": lateral_pts,
                }
                # Use the dicot pipeline to compute the plant traits on this frame
                plant_traits = dicot_pipeline.compute_frame_traits(initial_frame_traits)

                # For each plant's traits in the frame
                for trait_name, trait_value in plant_traits.items():
                    # Not all traits are added to the aggregated traits dictionary
                    if trait_name in dicot_pipeline.csv_traits_multiple_plants:
                        if trait_name not in aggregated_traits:
                            # Initialize the trait array if it's the first frame
                            aggregated_traits[trait_name] = [np.atleast_1d(trait_value)]
                        else:
                            # Append new trait values for subsequent frames
                            aggregated_traits[trait_name].append(
                                np.atleast_1d(trait_value)
                            )

        # After processing, update the result dictionary with computed traits
        for trait, arrays in aggregated_traits.items():
            aggregated_traits[trait] = np.concatenate(arrays, axis=0)
        result["traits"] = aggregated_traits

        # Write to JSON if requested
        if write_json:
            json_name = f"{series.series_name}{json_suffix}"
            try:
                with open(json_name, "w") as f:
                    json.dump(
                        result, f, cls=NumpyArrayEncoder, ensure_ascii=False, indent=4
                    )
                print(f"Aggregated traits saved to {json_name}")
            except IOError as e:
                print(f"Error writing JSON file '{json_name}': {e}")

        # Compute summary statistics and update result
        summary_stats = {}
        for trait_name, trait_values in aggregated_traits.items():
            trait_stats = get_summary(trait_values, prefix=f"{trait_name}_")
            summary_stats.update(trait_stats)
        result["summary_stats"] = summary_stats

        # Optionally write summary stats to CSV
        if write_csv:
            csv_name = f"{series.series_name}{csv_suffix}"
            try:
                summary_df = pd.DataFrame([summary_stats])
                summary_df.insert(0, "series", series.series_name)
                summary_df.to_csv(csv_name, index=False)
                print(f"Summary statistics saved to {csv_name}")
            except IOError as e:
                print(f"Failed to write CSV file '{csv_name}': {e}")

        # Return the final result structure
        return result
The `compute_multiple_dicots_traits` method is well-implemented, but some lines were not covered by tests. Ensuring comprehensive test coverage is crucial for verifying the correctness of this functionality.

+ # TODO: Add tests to cover the compute_multiple_dicots_traits functionality
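For reference, a hypothetical usage sketch of this method using the signature above (the Series-loading call is illustrative and may not match the actual Series API):

from sleap_roots import MultipleDicotPipeline
from sleap_roots.series import Series

# Hypothetical loader call; consult the Series class for the real signature.
series = Series.load(
    "tests/data/multiple_arabidopsis_11do/997_1.h5",
    primary_name="primary",
    lateral_name="lateral",
)

pipeline = MultipleDicotPipeline()
result = pipeline.compute_multiple_dicots_traits(
    series,
    write_json=True,  # writes "<series>.all_frames_traits.json"
    write_csv=True,   # writes "<series>.all_frames_summary.csv"
)
print(sorted(result["summary_stats"]))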
@attrs.define
class MultipleDicotPipeline(Pipeline):
    """Pipeline for computing traits for multiple dicot plants."""

    def define_traits(self) -> List[TraitDef]:
        """Define the trait computation pipeline for primary roots."""
        trait_definitions = [
            TraitDef(
                name="primary_pts_no_nans",
                fn=filter_roots_with_nans,
                input_traits=["primary_pts"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Primary roots without any NaNs.",
            ),
            TraitDef(
                name="lateral_pts_no_nans",
                fn=filter_roots_with_nans,
                input_traits=["lateral_pts"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Lateral roots without any NaNs.",
            ),
            TraitDef(
                name="filtered_pts_expected_plant_ct",
                fn=filter_plants_with_unexpected_ct,
                input_traits=[
                    "primary_pts_no_nans",
                    "lateral_pts_no_nans",
                    "expected_plant_ct",
                ],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Tuple of filtered points with expected plant count.",
            ),
            TraitDef(
                name="primary_pts_expected_plant_ct",
                fn=get_filtered_primary_pts,
                input_traits=["filtered_pts_expected_plant_ct"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Filtered primary root points with expected plant count.",
            ),
            TraitDef(
                name="lateral_pts_expected_plant_ct",
                fn=get_filtered_lateral_pts,
                input_traits=["filtered_pts_expected_plant_ct"],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Filtered lateral root points with expected plant count.",
            ),
            TraitDef(
                name="plant_associations_dict",
                fn=associate_lateral_to_primary,
                input_traits=[
                    "primary_pts_expected_plant_ct",
                    "lateral_pts_expected_plant_ct",
                ],
                scalar=False,
                include_in_csv=False,
                kwargs={},
                description="Dictionary of plant associations.",
            ),
        ]

        return trait_definitions

    def get_initial_frame_traits(self, plant: Series, frame_idx: int) -> Dict[str, Any]:
        """Return initial traits for a plant frame.

        Args:
            plant: The plant `Series` object.
            frame_idx: The index of the current frame.

        Returns:
            A dictionary of initial traits with keys:
                - "primary_pts": Array of primary root points.
                - "lateral_pts": Array of lateral root points.
                - "expected_ct": Expected number of plants as a float.
        """
        primary_pts = plant.get_primary_points(frame_idx)
        lateral_pts = plant.get_lateral_points(frame_idx)
        expected_plant_ct = plant.expected_count
        return {
            "primary_pts": primary_pts,
            "lateral_pts": lateral_pts,
            "expected_plant_ct": expected_plant_ct,
        }
The `MultipleDicotPipeline` class introduces important functionality for handling multiple dicot plants, including filtering and association logic. Ensure that this functionality is thoroughly tested, given its critical role in the pipeline.

+ # TODO: Ensure comprehensive testing for filtering and association logic in MultipleDicotPipeline
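A self-contained sketch of the association step at the heart of this pipeline, using tiny hypothetical arrays shaped (instances, nodes, 2) as the docstring above requires:

import numpy as np
from sleap_roots.points import associate_lateral_to_primary

primary_pts = np.array(
    [[[0.0, 0.0], [0.0, 10.0]],     # primary root 0 along x=0
     [[20.0, 0.0], [20.0, 10.0]]]   # primary root 1 along x=20
)
lateral_pts = np.array(
    [[[1.0, 2.0], [3.0, 2.0]],      # nearest to primary 0
     [[19.0, 5.0], [17.0, 5.0]],    # nearest to primary 1
     [[2.0, 8.0], [4.0, 8.0]]]      # nearest to primary 0
)

associations = associate_lateral_to_primary(primary_pts, lateral_pts)
for idx, assoc in associations.items():
    print(idx, assoc["lateral_points"].shape)  # 0 -> (2, 2, 2), 1 -> (1, 2, 2)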
- `MultipleDicotPipeline` added
- `Series` class altered to take in expected number of plants per series
- `DicotPipeline` for multiple plants in a series created

Summary by CodeRabbit
- New `MultipleDicotPipeline` for enhanced pipeline functionality.
- `Series` class functionality extended to include loading expected plant count from a CSV file.
- `get_max_length_pts` updated to handle empty arrays and incorrect shapes.