-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fixes and test scoring * include RCall dep * Minor change: increase target acceptance * Minor fixes to run inference after changes in EpiAware * scoring tool * remove summarize * change vline to fix to data availability times
- Loading branch information
1 parent
9bd7043
commit e506f7d
Showing
10 changed files
with
184 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
""" | ||
Internal function for making a DataFrame for a given parameter using the | ||
provided MCMC samples and the truth value. | ||
""" | ||
function _make_prediction_dataframe(param_name, samples, truth; model = "EpiAware") | ||
x = samples[Symbol(param_name)][:] | ||
DataFrame(predicted = x, observed = truth, model = model, | ||
parameter = param_name, sample_id = 1:length(x)) | ||
end | ||
|
||
""" | ||
Internal function for scoring a DataFrame containing a prediction and truth value | ||
for a parameter using the `scoringutils` package. | ||
""" | ||
function _score(df) | ||
@rput df | ||
R""" | ||
library(scoringutils) | ||
result = df |> as_forecast() |> score() | ||
""" | ||
@rget result | ||
return result | ||
end | ||
|
||
""" | ||
This function calculates standard scores provided by [`scoringutils`](https://epiforecasts.io/scoringutils/dev/) | ||
for a set of parameters using the provided MCMC samples and the truth value. | ||
The function returns a DataFrame containing a summary of the scores. | ||
## Arguments | ||
- `param_names`: Names of the parameter to score. | ||
- `samples`: A `MCMCChains.Chains` object of samples. | ||
- `truths`: Truth values for each parameter. | ||
- `model`: (optional) The name of the model. Default is "EpiAware". | ||
## Returns | ||
- `result`: A DataFrame containing the summarized scores for the parameter. | ||
""" | ||
function score_parameters(param_names, samples, truths; model = "EpiAware") | ||
df = mapreduce(vcat, param_names, truths) do param_name, truth | ||
_make_prediction_dataframe(param_name, samples, truth; model = model) | ||
end | ||
return _score(df) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include("score_parameters.jl") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
using DrWatson, Test | ||
quickactivate(@__DIR__(), "EpiAwarePipeline") | ||
|
||
@testset "run inference for random scenario and do scoring" begin | ||
using EpiAwarePipeline, EpiAware, Plots, Statistics, RCall, DataFramesMeta | ||
pipeline = EpiAwareExamplePipeline() | ||
prior = RtwithoutRenewalPriorPipeline() | ||
|
||
## Set up data generation on a random scenario | ||
|
||
missing_padding = 14 | ||
lookahead = 21 | ||
n_observation_steps = 35 | ||
tspan_gen = (1, n_observation_steps + lookahead + missing_padding) | ||
tspan_inf = (1, n_observation_steps + missing_padding) | ||
inference_method = make_inference_method(pipeline; ndraws = 2000) | ||
truth_data_config = make_truth_data_configs(pipeline)[1] | ||
inference_configs = make_inference_configs(pipeline) | ||
inference_config = rand(inference_configs) | ||
|
||
## Generate truth data and plot | ||
truth_sampling = generate_inference_results(inference_config, prior; tspan = tspan_gen) | ||
truthdata = truth_sampling["inference_results"].generated.generated_y_t | ||
## | ||
plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data") | ||
vline!(plt, [n_observation_steps + missing_padding + 0.5], label = "forecast start") | ||
|
||
##Generate true Rt values | ||
truth_It = truth_sampling["inference_results"].generated.I_t | ||
truth_Rt = expected_Rt(truth_sampling["epiprob"].epi_model.data, | ||
truth_sampling["inference_results"].generated.I_t) | ||
plt_Rt = plot(truth_Rt, xlabel = "t", ylabel = "R_t", label = "truth Rt") | ||
vline!(plt_Rt, [tspan_inf[2] + 0.5], label = "forecast start") | ||
|
||
## | ||
## Run inference | ||
obs_truthdata = truthdata[tspan_inf[1]:tspan_inf[2]] | ||
|
||
inference_results = generate_inference_results( | ||
Dict("y_t" => obs_truthdata, "truth_gi_mean" => inference_config["gi_mean"]), | ||
inference_config, pipeline; tspan = tspan_inf, inference_method) | ||
|
||
## Make 21-day forecast | ||
|
||
forecast_quantities = generate_forecasts(inference_results, lookahead) | ||
forecast_y_t = mapreduce(hcat, forecast_quantities.generated) do gen | ||
gen.generated_y_t | ||
end | ||
forecast_qs = mapreduce(hcat, [0.025, 0.25, 0.5, 0.75, 0.975]) do q | ||
map(eachrow(forecast_y_t)) do row | ||
if any(ismissing, row) | ||
missing | ||
else | ||
quantile(row, q) | ||
end | ||
end | ||
end | ||
plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data") | ||
vline!(plt, [tspan_inf[2] + 0.5], label = "forecast start") | ||
plot!(plt, forecast_qs, label = "forecast quantiles", | ||
color = :grey, lw = [0.5 1.5 3 1.5 0.5]) | ||
plot!(plt, ylims = (-0.5, maximum(truthdata) * 1.25)) | ||
plot!( | ||
plt, title = "Forecast of y_t", ylims = ( | ||
-0.5, maximum(skipmissing(truthdata)) * 1.55)) | ||
savefig(plt, | ||
joinpath(@__DIR__(), "forecast_y_t.png") | ||
) | ||
display(plt) | ||
|
||
## Make forecast plot for Z_t | ||
infer_Z_t = mapreduce(hcat, inference_results["inference_results"].generated) do gen | ||
gen.Z_t | ||
end | ||
forecast_Z_t = mapreduce(hcat, forecast_quantities.generated) do gen | ||
gen.Z_t | ||
end | ||
plt_Zt = plot( | ||
truth_sampling["inference_results"].generated.Z_t, lw = 3, color = :black, label = "truth Z_t") | ||
plot!(plt_Zt, infer_Z_t, xlabel = "t", ylabel = "Z_t", | ||
label = "", color = :grey, alpha = 0.05) | ||
plot!((n_observation_steps + 1):size(forecast_Z_t, 1), | ||
forecast_Z_t[(n_observation_steps + 1):end, :], | ||
label = "", color = :red, alpha = 0.05) | ||
vline!(plt_Zt, [n_observation_steps], label = "forecast start") | ||
|
||
savefig(plt_Zt, | ||
joinpath(@__DIR__(), "forecast_Z_t.png") | ||
) | ||
display(plt_Zt) | ||
|
||
## Make forecast plot for Rt | ||
param_names = forecast_quantities.samples.name_map[:parameters] .|> string | ||
obs_yts = filter(str -> occursin("y_t", str), param_names) | ||
|
||
scores = score_parameters(obs_yts, forecast_quantities.samples, truthdata) | ||
plt_crps = scatter(scores.crps, xlabel = "t", ylabel = "CRPS", label = "CRPS") | ||
savefig(plt_crps, | ||
joinpath(@__DIR__(), "crps.png") | ||
) | ||
display(plt_crps) | ||
|
||
@test scores isa DataFrame | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
@testset "score_parameter tests" begin | ||
using MCMCChains, EpiAwarePipeline | ||
|
||
samples = MCMCChains.Chains(0.5 .+ randn(1000, 2, 1), [:a, :b]) | ||
truths = fill(0.5, 2) | ||
result = score_parameters(["a", "b"], samples, truths) | ||
|
||
@test result.parameter == ["a", "b"] | ||
#Bias should be close to 0 in this example | ||
@test all(result.bias .< 0.1) | ||
end |