scoring tool for pipeline (#282)
* fixes and test scoring

* include RCall dep

* Minor change: increase target acceptance

* Minor fixes to run inference after changes in EpiAware

* scoring tool

* remove summarize

* change vline to fix to data availability times
SamuelBrand1 authored Jun 14, 2024
1 parent 9bd7043 commit e506f7d
Showing 10 changed files with 184 additions and 13 deletions.
1 change: 1 addition & 0 deletions pipeline/Project.toml
@@ -18,6 +18,7 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
RCall = "6f49c342-dc21-5d91-9882-a32aef131414"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
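
Since the pipeline now depends on RCall, a working R installation is needed alongside Julia: RCall locates R at build time via the `R_HOME` environment variable (or the system `PATH`). A hedged setup sketch, with a purely illustrative path:

ENV["R_HOME"] = "/usr/lib/R"  # illustrative path; point this at the local R installation
using Pkg
Pkg.build("RCall")            # rebuild RCall so it picks up the chosen R
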
6 changes: 5 additions & 1 deletion pipeline/src/EpiAwarePipeline.jl
@@ -12,7 +12,7 @@ module EpiAwarePipeline

using CSV, Dagger, DataFramesMeta, Dates, Distributions, DocStringExtensions, DrWatson,
EpiAware, Plots, Statistics, ADTypes, AbstractMCMC, Plots, JLD2, MCMCChains, Turing,
DynamicPPL, LogExpFunctions
DynamicPPL, LogExpFunctions, RCall

# Exported pipeline types
export AbstractEpiAwarePipeline, EpiAwarePipeline, RtwithoutRenewalPipeline,
@@ -38,6 +38,9 @@ export infer, generate_inference_results, map_inference_results, define_epiprob
# Exported functions: forecast functions
export define_forecast_epiprob, generate_forecasts

# Exported functions: scoring functions
export score_parameters

# Exported functions: plot functions
export plot_truth_data, plot_Rt

@@ -47,5 +50,6 @@ include("constructors/constructors.jl")
include("simulate/simulate.jl")
include("infer/infer.jl")
include("forecast/forecast.jl")
include("scoring/score_parameters.jl")
include("plot_functions.jl")
end
3 changes: 2 additions & 1 deletion pipeline/src/constructors/make_inference_method.jl
@@ -37,7 +37,8 @@ function make_inference_method(
nruns_pthf::Integer = 4, maxiters_pthf::Integer = 100, nchains::Integer = 4)
return EpiMethod(
pre_sampler_steps = [ManyPathfinder(nruns = nruns_pthf, maxiters = maxiters_pthf)],
sampler = NUTSampler(adtype = AutoForwardDiff(), ndraws = ndraws,
sampler = NUTSampler(
target_acceptance = 0.9, adtype = AutoForwardDiff(), ndraws = ndraws,
nchains = nchains, mcmc_parallel = mcmc_ensemble)
)
end
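
Setting `target_acceptance = 0.9` (the "increase target acceptance" item in the commit message) makes NUTS adapt towards smaller step sizes, which typically reduces divergent transitions at the cost of somewhat slower sampling. For reference, the constructor is invoked downstream roughly as in the end-to-end test added by this commit (`nchains` shown at its default):

inference_method = make_inference_method(pipeline; ndraws = 2000, nchains = 4)
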
4 changes: 2 additions & 2 deletions pipeline/src/infer/InferenceConfig.jl
@@ -19,7 +19,7 @@ struct InferenceConfig{T, F, I, L, E}
"Latent model type."
latent_model::L
"Case data"
case_data::Union{Vector{Integer, Missing}, Missing}
case_data::Union{Vector{Union{Integer, Missing}}, Missing}
"Time to fit on"
tspan::Tuple{Integer, Integer}
"Inference method."
@@ -84,7 +84,7 @@ function infer(config::InferenceConfig)
y_t = ismissing(config.case_data) ? missing : config.case_data[idxs]
inference_results = apply_method(epi_prob,
config.epimethod,
(y_t = y_t,)
(y_t = y_t,);
)
return Dict("inference_results" => inference_results, "epiprob" => epi_prob)
end
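
A note on the type fix above: `Vector{Integer, Missing}` is not a valid Julia type (a `Vector` takes a single element-type parameter), so the field could not represent case data with missing entries. The corrected annotation accepts either `missing` (no data at all, as in the prior predictive path) or a vector whose individual elements may be missing. Illustrative check with hypothetical values:

case_data = Union{Integer, Missing}[10, 12, missing, 15]
case_data isa Union{Vector{Union{Integer, Missing}}, Missing}  # true
missing isa Union{Vector{Union{Integer, Missing}}, Missing}    # also true
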
21 changes: 12 additions & 9 deletions pipeline/src/infer/generate_inference_results.jl
@@ -23,7 +23,8 @@ function generate_inference_results(
# produce or load inference results
prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" *
inference_config["latent_namemodels"].first * "_truth_gi_mean_" *
string(truthdata["truth_gi_mean"])
string(truthdata["truth_gi_mean"]) * "_used_gi_mean_" *
string(inference_config["gi_mean"])

inference_results, inferencefile = produce_or_load(
infer, config, datadir(datadir_name); prefix = prfx)
@@ -63,9 +64,10 @@ function generate_inference_results(
# produce or load inference results
prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" *
inference_config["latent_namemodels"].first * "_truth_gi_mean_" *
string(truthdata["truth_gi_mean"])
string(truthdata["truth_gi_mean"]) * "_used_gi_mean_" *
string(inference_config["gi_mean"])

datadir_name, io = mktemp(; cleanup = true)
datadir_name = mktempdir()

inference_results, inferencefile = produce_or_load(
infer, config, datadir_name; prefix = prfx)
@@ -76,16 +78,17 @@ end
Method for prior predictive modelling.
"""
function generate_inference_results(
truthdata, inference_config, pipeline::RtwithoutRenewalPriorPipeline;
tspan, inference_method,
prfix_name = "prior_observables", datadir_name = "epiaware_observables")
inference_config, pipeline::RtwithoutRenewalPriorPipeline;
tspan, prefix_name = "prior_observables")
config = InferenceConfig(
inference_config; case_data = missing, tspan, epimethod = inference_method)
inference_config; case_data = missing, tspan, epimethod = DirectSample())

# produce or load inference results
prfx = prfix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" *
prfx = prefix_name * "_igp_" * string(inference_config["igp"]) * "_latentmodel_" *
inference_config["latent_namemodels"].first * "_truth_gi_mean_" *
string(truthdata["truth_gi_mean"])
string(inference_config["gi_mean"])

datadir_name = mktempdir()

inference_results, inferencefile = produce_or_load(
infer, config, datadir(datadir_name); prefix = prfx)
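
Two small changes here: the results-file prefix now also records the generation interval mean assumed for inference (`_used_gi_mean_`), so runs with different assumed GI means no longer overwrite each other; and `mktemp` (which returns a file path plus an open IO handle) is swapped for `mktempdir`, since `produce_or_load` wants a directory to save into. Roughly, mirroring the calls above:

datadir_name = mktempdir()  # temporary directory, cleaned up at process exit by default
inference_results, inferencefile = produce_or_load(
    infer, config, datadir_name; prefix = prfx)
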
45 changes: 45 additions & 0 deletions pipeline/src/scoring/score_parameters.jl
@@ -0,0 +1,45 @@
"""
Internal function for making a DataFrame for a given parameter using the
provided MCMC samples and the truth value.
"""
function _make_prediction_dataframe(param_name, samples, truth; model = "EpiAware")
x = samples[Symbol(param_name)][:]
DataFrame(predicted = x, observed = truth, model = model,
parameter = param_name, sample_id = 1:length(x))
end

"""
Internal function for scoring a DataFrame containing predictions and truth values
for a parameter, using the R `scoringutils` package via RCall.
"""
function _score(df)
@rput df
R"""
library(scoringutils)
result = df |> as_forecast() |> score()
"""
@rget result
return result
end

"""
This function calculates standard scores provided by [`scoringutils`](https://epiforecasts.io/scoringutils/dev/)
for a set of parameters using the provided MCMC samples and the truth value.
The function returns a DataFrame containing the scores.
## Arguments
- `param_names`: Names of the parameters to score.
- `samples`: A `MCMCChains.Chains` object of samples.
- `truths`: Truth values for each parameter.
- `model`: (optional) The name of the model. Default is "EpiAware".
## Returns
- `result`: A DataFrame containing the scores, one row per parameter.
"""
function score_parameters(param_names, samples, truths; model = "EpiAware")
df = mapreduce(vcat, param_names, truths) do param_name, truth
_make_prediction_dataframe(param_name, samples, truth; model = model)
end
return _score(df)
end
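
`score_parameters` ships the per-parameter prediction DataFrame to R via RCall, so the R `scoringutils` package must be installed in the R library that RCall is built against. A minimal usage sketch, with hypothetical parameter names and truths mirroring the unit test added below:

using MCMCChains, EpiAwarePipeline

chn = MCMCChains.Chains(0.5 .+ randn(2000, 2, 1), [:a, :b])  # draws centred on the truths
scores = score_parameters(["a", "b"], chn, fill(0.5, 2))
scores.crps  # one row per parameter; column names follow scoringutils' score() output
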
1 change: 1 addition & 0 deletions pipeline/src/scoring/scoring.jl
@@ -0,0 +1 @@
include("score_parameters.jl")
104 changes: 104 additions & 0 deletions pipeline/test/end-to-end/test_scoring.jl
@@ -0,0 +1,104 @@
using DrWatson, Test
quickactivate(@__DIR__(), "EpiAwarePipeline")

@testset "run inference for random scenario and do scoring" begin
using EpiAwarePipeline, EpiAware, Plots, Statistics, RCall, DataFramesMeta
pipeline = EpiAwareExamplePipeline()
prior = RtwithoutRenewalPriorPipeline()

## Set up data generation on a random scenario

missing_padding = 14
lookahead = 21
n_observation_steps = 35
tspan_gen = (1, n_observation_steps + lookahead + missing_padding)
tspan_inf = (1, n_observation_steps + missing_padding)
inference_method = make_inference_method(pipeline; ndraws = 2000)
truth_data_config = make_truth_data_configs(pipeline)[1]
inference_configs = make_inference_configs(pipeline)
inference_config = rand(inference_configs)

## Generate truth data and plot
truth_sampling = generate_inference_results(inference_config, prior; tspan = tspan_gen)
truthdata = truth_sampling["inference_results"].generated.generated_y_t
##
plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data")
vline!(plt, [n_observation_steps + missing_padding + 0.5], label = "forecast start")

## Generate true Rt values
truth_It = truth_sampling["inference_results"].generated.I_t
truth_Rt = expected_Rt(truth_sampling["epiprob"].epi_model.data,
truth_sampling["inference_results"].generated.I_t)
plt_Rt = plot(truth_Rt, xlabel = "t", ylabel = "R_t", label = "truth Rt")
vline!(plt_Rt, [tspan_inf[2] + 0.5], label = "forecast start")

##
## Run inference
obs_truthdata = truthdata[tspan_inf[1]:tspan_inf[2]]

inference_results = generate_inference_results(
Dict("y_t" => obs_truthdata, "truth_gi_mean" => inference_config["gi_mean"]),
inference_config, pipeline; tspan = tspan_inf, inference_method)

## Make 21-day forecast

forecast_quantities = generate_forecasts(inference_results, lookahead)
forecast_y_t = mapreduce(hcat, forecast_quantities.generated) do gen
gen.generated_y_t
end
forecast_qs = mapreduce(hcat, [0.025, 0.25, 0.5, 0.75, 0.975]) do q
map(eachrow(forecast_y_t)) do row
if any(ismissing, row)
missing
else
quantile(row, q)
end
end
end
plt = scatter(truthdata, xlabel = "t", ylabel = "y_t", label = "truth data")
vline!(plt, [tspan_inf[2] + 0.5], label = "forecast start")
plot!(plt, forecast_qs, label = "forecast quantiles",
color = :grey, lw = [0.5 1.5 3 1.5 0.5])
plot!(plt, ylims = (-0.5, maximum(truthdata) * 1.25))
plot!(
plt, title = "Forecast of y_t", ylims = (
-0.5, maximum(skipmissing(truthdata)) * 1.55))
savefig(plt,
joinpath(@__DIR__(), "forecast_y_t.png")
)
display(plt)

## Make forecast plot for Z_t
infer_Z_t = mapreduce(hcat, inference_results["inference_results"].generated) do gen
gen.Z_t
end
forecast_Z_t = mapreduce(hcat, forecast_quantities.generated) do gen
gen.Z_t
end
plt_Zt = plot(
truth_sampling["inference_results"].generated.Z_t, lw = 3, color = :black, label = "truth Z_t")
plot!(plt_Zt, infer_Z_t, xlabel = "t", ylabel = "Z_t",
label = "", color = :grey, alpha = 0.05)
plot!((n_observation_steps + 1):size(forecast_Z_t, 1),
forecast_Z_t[(n_observation_steps + 1):end, :],
label = "", color = :red, alpha = 0.05)
vline!(plt_Zt, [n_observation_steps], label = "forecast start")

savefig(plt_Zt,
joinpath(@__DIR__(), "forecast_Z_t.png")
)
display(plt_Zt)

## Score y_t forecasts against the truth data and plot CRPS
param_names = forecast_quantities.samples.name_map[:parameters] .|> string
obs_yts = filter(str -> occursin("y_t", str), param_names)

scores = score_parameters(obs_yts, forecast_quantities.samples, truthdata)
plt_crps = scatter(scores.crps, xlabel = "t", ylabel = "CRPS", label = "CRPS")
savefig(plt_crps,
joinpath(@__DIR__(), "crps.png")
)
display(plt_crps)

@test scores isa DataFrame
end
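
Note that this end-to-end script is not added to runtests.jl below (only the lighter scoring/test_score_parameters.jl unit test is), so presumably it is run manually, e.g. from a Julia session started in the pipeline directory:

include(joinpath("test", "end-to-end", "test_scoring.jl"))  # quickactivate at the top handles project activation
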
1 change: 1 addition & 0 deletions pipeline/test/runtests.jl
@@ -9,3 +9,4 @@ include("simulate/test_SimulationConfig.jl");
include("infer/test_InferenceConfig.jl");
include("infer/test_define_epiprob.jl");
include("forecast/test_forecast.jl");
include("scoring/test_score_parameters.jl");
11 changes: 11 additions & 0 deletions pipeline/test/scoring/test_score_parameters.jl
@@ -0,0 +1,11 @@
@testset "score_parameter tests" begin
using MCMCChains, EpiAwarePipeline

samples = MCMCChains.Chains(0.5 .+ randn(1000, 2, 1), [:a, :b])
truths = fill(0.5, 2)
result = score_parameters(["a", "b"], samples, truths)

@test result.parameter == ["a", "b"]
    # Bias should be close to 0 in this example, so check its magnitude
    @test all(abs.(result.bias) .< 0.1)
end
