-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement basic scoring as part of inference (#401)
* Update return in `simulate` with more unit tests * reformat * Extra variables in Inference Config and update unit tests * Make the prefix a pipeline struct field * fix underlying functions * Add a simple crps function * remove pipeline arg * Add crps summary to infer along with end-to-end test update
- Loading branch information
1 parent
900ff01
commit 2b816a8
Showing
19 changed files
with
202 additions
and
130 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
include("score_parameters.jl") | ||
include("simple_crps.jl") | ||
include("summarise_crps.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
""" | ||
Compute the empirical Continuous Ranked Probability Score (CRPS) for a predictive | ||
distribution defined by the samples `forecasts` with respect to an observed value | ||
`observation`. | ||
The CRPS is defined as the sum of the Mean Absolute Error (MAE) and a pseudo-entropy | ||
term that measures the spread of the forecast distribution. | ||
```math | ||
CRPS = E[|Y - X|] - 0.5 E[|X - X'|] | ||
``` | ||
Where `Y` is the observed value, and `X` and `X'` are two random variables drawn from the | ||
forecast distribution. | ||
# Arguments | ||
- `forecasts`: A vector of forecasted values. | ||
- `observation`: The observed value. | ||
# Returns | ||
- `crps`: The computed CRPS. | ||
# Example | ||
```julia | ||
using EpiAwarePipeline | ||
forecasts = randn(100) | ||
observation = randn() | ||
crps = simple_crps(forecasts, observation) | ||
``` | ||
""" | ||
function simple_crps(forecasts, observation) | ||
@assert !isempty(forecasts) "Forecasts cannot be empty" | ||
mae = mean(abs, forecasts .- observation) | ||
pseudo_entropy = -0.5 * mean(abs, [x - y for x in forecasts, y in forecasts]) | ||
return mae + pseudo_entropy | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
""" | ||
Summarizes the Continuous Ranked Probability Score (CRPS) for different processes based on the inference results. | ||
# Arguments | ||
- `inference_results`: A dictionary containing the inference results, including | ||
the forecast results and inference configuration. | ||
# Returns | ||
A dictionary containing the summarized CRPS scores for different processes. | ||
""" | ||
function summarise_crps(config, inference_results, forecast_results, epiprob) | ||
ts = config.tspan[1]:min(config.tspan[2] + config.lookahead, length(config.truth_I_t)) | ||
epidata = epiprob.epi_model.data | ||
|
||
procs_names = (:log_I_t, :rt, :Rt, :I_t, :log_Rt) | ||
scores_log_I_t, scores_rt, scores_Rt, scores_I_t, scores_log_Rt = _process_crps_scores( | ||
procs_names, inference_results, forecast_results, config, ts, epidata) | ||
|
||
scores_y_t, scores_log_y_t = _cases_crps_scores(forecast_results, config, ts) | ||
|
||
return Dict("ts" => ts, "scores_log_I_t" => scores_log_I_t, | ||
"scores_rt" => scores_rt, "scores_Rt" => scores_Rt, | ||
"scores_I_t" => scores_I_t, "scores_log_Rt" => scores_log_Rt, | ||
"scores_y_t" => scores_y_t, "scores_log_y_t" => scores_log_y_t) | ||
end | ||
|
||
""" | ||
Internal method for calculating the CRPS scores for different processes. | ||
""" | ||
function _process_crps_scores( | ||
procs_names, inference_results, forecast_results, config, ts, epidata) | ||
map(procs_names) do process | ||
# Calculate the processes for the truth data | ||
true_Itminusone = ts[1] - 1 == 0 ? config.truth_I0 : config.truth_I_t[ts[1] - 1] | ||
true_proc = calculate_processes( | ||
config.truth_I_t[ts], true_Itminusone, epidata) |> | ||
procs -> getfield(procs, process) | ||
# predictions | ||
gens = forecast_results.generated | ||
log_I0s = inference_results.samples[:init_incidence] | ||
predicted_proc = mapreduce(hcat, gens, log_I0s) do gen, logI0 | ||
I0 = exp(logI0) | ||
It = gen.I_t | ||
procs = calculate_processes(It, I0, epidata) | ||
getfield(procs, process) | ||
end | ||
scores = [simple_crps(preds, true_proc[t]) | ||
for (t, preds) in enumerate(eachrow(predicted_proc))] | ||
return scores | ||
end | ||
end | ||
|
||
""" | ||
Internal method for calculating the CRPS scores for observed cases and log(cases), | ||
including the forecast score for future cases. | ||
""" | ||
function _cases_crps_scores(forecast_results, config, ts; jitter = 1e-6) | ||
true_y_t = config.case_data[ts] | ||
gens = forecast_results.generated | ||
predicted_y_t = mapreduce(hcat, gens) do gen | ||
gen.generated_y_t | ||
end | ||
scores_y_t = [simple_crps(preds, true_y_t[t]) | ||
for (t, preds) in enumerate(eachrow(predicted_y_t))] | ||
scores_log_y_t = [simple_crps(log.(preds .+ jitter), log(true_y_t[t] + jitter)) | ||
for (t, preds) in enumerate(eachrow(predicted_y_t))] | ||
return scores_y_t, scores_log_y_t | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.