Skip to content

Commit

Permalink
Merge pull request #161 from FlorianPfaff/feature/det_all_dev
Browse files Browse the repository at this point in the history
Evaluation: Added function to determine all deviations
  • Loading branch information
FlorianPfaff authored Sep 6, 2023
2 parents 30075d8 + f783623 commit 7c4bafd
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pyrecest/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .check_and_fix_params import check_and_fix_params
from .configure_for_filter import configure_for_filter
from .determine_all_deviations import determine_all_deviations
from .generate_groundtruth import generate_groundtruth
from .generate_measurements import generate_measurements
from .get_axis_label import get_axis_label
Expand All @@ -18,6 +19,7 @@
"configure_for_filter",
"perform_predict_update_cycles",
"iterate_configs_and_runs",
"determine_all_deviations",
"start_evaluation",
"get_axis_label",
"get_distance_function",
Expand Down
48 changes: 48 additions & 0 deletions pyrecest/evaluation/determine_all_deviations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import warnings

import numpy as np


def determine_all_deviations(
results,
extract_mean,
distance_function,
groundtruths,
mean_calculation_symm="",
):
if mean_calculation_symm != "":
raise NotImplementedError("Not implemented yet")

assert (
np.ndim(groundtruths) == 3
), "Assuming groundtruths to be a 3D array of shape (n_runs, n_timesteps, state_dimension)"

all_deviations_last_mat = np.empty((len(results), groundtruths.shape[0]))

for config, result_curr_config in enumerate(results):
for run in range(len(groundtruths)):
if "last_filter_states" not in result_curr_config:
final_estimate = result_curr_config["last_estimates"][run]
elif callable(extract_mean):
final_estimate = extract_mean(
result_curr_config["last_filter_states"][run]
)
else:
raise ValueError("No compatible mean extraction function given.")

if final_estimate is not None:
all_deviations_last_mat[config][run] = distance_function(
final_estimate, groundtruths[run, -1, :]
)
else:
warnings.warn("No estimate for this filter, setting error to inf.")
all_deviations_last_mat[config][run] = np.inf

if np.any(np.isinf(all_deviations_last_mat[config])):
print(
f"Warning: {result_curr_config['filterName']} with {result_curr_config['filterParams']} "
f"parameters apparently failed {np.sum(np.isinf(all_deviations_last_mat[config]))} "
"times. Check if this is plausible."
)

return all_deviations_last_mat
45 changes: 45 additions & 0 deletions pyrecest/tests/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from pyrecest.evaluation import (
check_and_fix_params,
configure_for_filter,
determine_all_deviations,
generate_groundtruth,
generate_measurements,
get_axis_label,
Expand Down Expand Up @@ -55,6 +56,50 @@ def test_generate_meas_R2(self):
self.scenario_param["n_meas_at_individual_time_step"][i],
)

def test_determine_all_deviations(self):
def dummy_extract_mean(x):
return x

def dummy_distance_function(x, y):
return np.linalg.norm(x - y)

groundtruths = np.array([[[1, 2, 3], [2, 3, 4]], [[11, 12, 13], [12, 13, 14]]])
results = [
{
"filter_name": "filter1",
"filter_param": "params1",
"last_estimates": groundtruths[:, -1, :],
},
{
"filter_name": "filter2",
"filter_param": "params2",
"last_estimates": groundtruths[:, -1, :] + 1,
},
]

# Run the function and get the deviations matrix
all_deviations = determine_all_deviations(
results,
dummy_extract_mean,
dummy_distance_function,
groundtruths,
)

# Check the shape of the output matrices
assert len(all_deviations) == len(results)

# Validate some of the results
np.testing.assert_allclose(
# Should be zeros as the lastEstimates match groundtruths
all_deviations[0],
[0, 0],
)
np.testing.assert_allclose(
# Should be np.sqrt(2) away from groundtruths
all_deviations[1],
[np.sqrt(3), np.sqrt(3)],
)

def test_configure_kf(self):
filterParam = {"name": "kf", "parameter": None}
scenarioParam = {
Expand Down

0 comments on commit 7c4bafd

Please sign in to comment.