diff --git a/pyrecest/evaluation/__init__.py b/pyrecest/evaluation/__init__.py index 16e7473a..d8489c00 100644 --- a/pyrecest/evaluation/__init__.py +++ b/pyrecest/evaluation/__init__.py @@ -1,5 +1,6 @@ from .check_and_fix_params import check_and_fix_params from .configure_for_filter import configure_for_filter +from .determine_all_deviations import determine_all_deviations from .generate_groundtruth import generate_groundtruth from .generate_measurements import generate_measurements from .get_axis_label import get_axis_label @@ -18,6 +19,7 @@ "configure_for_filter", "perform_predict_update_cycles", "iterate_configs_and_runs", + "determine_all_deviations", "start_evaluation", "get_axis_label", "get_distance_function", diff --git a/pyrecest/evaluation/determine_all_deviations.py b/pyrecest/evaluation/determine_all_deviations.py new file mode 100644 index 00000000..9c50d734 --- /dev/null +++ b/pyrecest/evaluation/determine_all_deviations.py @@ -0,0 +1,48 @@ +import warnings + +import numpy as np + + +def determine_all_deviations( + results, + extract_mean, + distance_function, + groundtruths, + mean_calculation_symm="", +): + if mean_calculation_symm != "": + raise NotImplementedError("Not implemented yet") + + assert ( + np.ndim(groundtruths) == 3 + ), "Assuming groundtruths to be a 3D array of shape (n_runs, n_timesteps, state_dimension)" + + all_deviations_last_mat = np.empty((len(results), groundtruths.shape[0])) + + for config, result_curr_config in enumerate(results): + for run in range(len(groundtruths)): + if "last_filter_states" not in result_curr_config: + final_estimate = result_curr_config["last_estimates"][run] + elif callable(extract_mean): + final_estimate = extract_mean( + result_curr_config["last_filter_states"][run] + ) + else: + raise ValueError("No compatible mean extraction function given.") + + if final_estimate is not None: + all_deviations_last_mat[config][run] = distance_function( + final_estimate, groundtruths[run, -1, :] + ) + else: + warnings.warn("No estimate for this filter, setting error to inf.") + all_deviations_last_mat[config][run] = np.inf + + if np.any(np.isinf(all_deviations_last_mat[config])): + print( + f"Warning: {result_curr_config['filterName']} with {result_curr_config['filterParams']} " + f"parameters apparently failed {np.sum(np.isinf(all_deviations_last_mat[config]))} " + "times. Check if this is plausible." + ) + + return all_deviations_last_mat diff --git a/pyrecest/tests/test_evaluation.py b/pyrecest/tests/test_evaluation.py index 92331c95..70603619 100644 --- a/pyrecest/tests/test_evaluation.py +++ b/pyrecest/tests/test_evaluation.py @@ -10,6 +10,7 @@ from pyrecest.evaluation import ( check_and_fix_params, configure_for_filter, + determine_all_deviations, generate_groundtruth, generate_measurements, get_axis_label, @@ -55,6 +56,50 @@ def test_generate_meas_R2(self): self.scenario_param["n_meas_at_individual_time_step"][i], ) + def test_determine_all_deviations(self): + def dummy_extract_mean(x): + return x + + def dummy_distance_function(x, y): + return np.linalg.norm(x - y) + + groundtruths = np.array([[[1, 2, 3], [2, 3, 4]], [[11, 12, 13], [12, 13, 14]]]) + results = [ + { + "filter_name": "filter1", + "filter_param": "params1", + "last_estimates": groundtruths[:, -1, :], + }, + { + "filter_name": "filter2", + "filter_param": "params2", + "last_estimates": groundtruths[:, -1, :] + 1, + }, + ] + + # Run the function and get the deviations matrix + all_deviations = determine_all_deviations( + results, + dummy_extract_mean, + dummy_distance_function, + groundtruths, + ) + + # Check the shape of the output matrices + assert len(all_deviations) == len(results) + + # Validate some of the results + np.testing.assert_allclose( + # Should be zeros as the lastEstimates match groundtruths + all_deviations[0], + [0, 0], + ) + np.testing.assert_allclose( + # Should be np.sqrt(2) away from groundtruths + all_deviations[1], + [np.sqrt(3), np.sqrt(3)], + ) + def test_configure_kf(self): filterParam = {"name": "kf", "parameter": None} scenarioParam = {