diff --git a/pyrecest/evaluation/determine_all_deviations.py b/pyrecest/evaluation/determine_all_deviations.py index dbc431a7..9c50d734 100644 --- a/pyrecest/evaluation/determine_all_deviations.py +++ b/pyrecest/evaluation/determine_all_deviations.py @@ -1,6 +1,7 @@ -import numpy as np import warnings +import numpy as np + def determine_all_deviations( results, @@ -23,7 +24,9 @@ def determine_all_deviations( if "last_filter_states" not in result_curr_config: final_estimate = result_curr_config["last_estimates"][run] elif callable(extract_mean): - final_estimate = extract_mean(result_curr_config["last_filter_states"][run]) + final_estimate = extract_mean( + result_curr_config["last_filter_states"][run] + ) else: raise ValueError("No compatible mean extraction function given.") diff --git a/pyrecest/tests/test_evaluation.py b/pyrecest/tests/test_evaluation.py index 45c6043a..70603619 100644 --- a/pyrecest/tests/test_evaluation.py +++ b/pyrecest/tests/test_evaluation.py @@ -59,7 +59,7 @@ def test_generate_meas_R2(self): def test_determine_all_deviations(self): def dummy_extract_mean(x): return x - + def dummy_distance_function(x, y): return np.linalg.norm(x - y) @@ -91,11 +91,13 @@ def dummy_distance_function(x, y): # Validate some of the results np.testing.assert_allclose( # Should be zeros as the lastEstimates match groundtruths - all_deviations[0], [0, 0] + all_deviations[0], + [0, 0], ) np.testing.assert_allclose( # Should be np.sqrt(2) away from groundtruths - all_deviations[1], [np.sqrt(3), np.sqrt(3)] + all_deviations[1], + [np.sqrt(3), np.sqrt(3)], ) def test_configure_kf(self):