Skip to content

Commit

Permalink
Implement basic scoring as part of inference (#401)
Browse files Browse the repository at this point in the history
* Update return in `simulate` with more unit tests

* reformat

* Extra variables in Inference Config and update unit tests

* Make the prefix a pipeline struct field

* fix underlying functions

* Add a simple crps function

* remove pipeline arg

* Add crps summary to infer along with end-to-end test update
  • Loading branch information
SamuelBrand1 authored Jul 25, 2024
1 parent 900ff01 commit 2b816a8
Show file tree
Hide file tree
Showing 19 changed files with 202 additions and 130 deletions.
4 changes: 2 additions & 2 deletions pipeline/src/EpiAwarePipeline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ export infer, generate_inference_results, map_inference_results, define_epiprob
export define_forecast_epiprob, generate_forecasts

# Exported functions: scoring functions
export score_parameters
export score_parameters, simple_crps, summarise_crps

# Exported functions: Analysis functions for constructing dataframes
export make_prediction_dataframe_from_output, make_truthdata_dataframe
Expand All @@ -64,7 +64,7 @@ include("constructors/constructors.jl")
include("simulate/simulate.jl")
include("infer/infer.jl")
include("forecast/forecast.jl")
include("scoring/score_parameters.jl")
include("scoring/scoring.jl")
include("analysis/analysis.jl")
include("mainplots/mainplots.jl")
include("plot_functions.jl")
Expand Down
22 changes: 15 additions & 7 deletions pipeline/src/infer/InferenceConfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,37 +18,42 @@ Inference configuration struct for specifying the parameters and models used in
- `InferenceConfig(inference_config::Dict; case_data, tspan, epimethod)`: Constructs an `InferenceConfig` object from a dictionary of configuration values.
"""
struct InferenceConfig{T, F, I, L, O, E, D <: Distribution, X <: Integer}
struct InferenceConfig{T, F, IGP, L, O, E, D <: Distribution, X <: Integer}
gi_mean::T
gi_std::T
igp::I
igp::IGP
latent_model::L
observation_model::O
case_data::Union{Vector{Union{Integer, Missing}}, Missing}
truth_I_t::Vector{T}
truth_I0::T
tspan::Tuple{Integer, Integer}
epimethod::E
transformation::F
log_I0_prior::D
lookahead::X

function InferenceConfig(igp, latent_model, observation_model; gi_mean, gi_std,
case_data, tspan, epimethod, transformation = exp, log_I0_prior, lookahead)
case_data, truth_I_t, truth_I0, tspan, epimethod,
transformation = exp, log_I0_prior, lookahead)
new{typeof(gi_mean), typeof(transformation),
typeof(igp), typeof(latent_model), typeof(observation_model),
typeof(epimethod), typeof(log_I0_prior), typeof(lookahead)}(
gi_mean, gi_std, igp, latent_model, observation_model,
case_data, tspan, epimethod, transformation, log_I0_prior, lookahead)
case_data, truth_I_t, truth_I0, tspan, epimethod, transformation, log_I0_prior, lookahead)
end

function InferenceConfig(
inference_config::Dict; case_data, tspan, epimethod)
inference_config::Dict; case_data, truth_I_t, truth_I0, tspan, epimethod)
InferenceConfig(
inference_config["igp"],
inference_config["latent_namemodels"].second,
inference_config["observation_model"];
gi_mean = inference_config["gi_mean"],
gi_std = inference_config["gi_std"],
case_data = case_data,
truth_I_t = truth_I_t,
truth_I0 = truth_I0,
tspan = tspan,
epimethod = epimethod,
log_I0_prior = inference_config["log_I0_prior"],
Expand Down Expand Up @@ -84,6 +89,9 @@ function infer(config::InferenceConfig)
forecast_results = generate_forecasts(
inference_results.samples, inference_results.data, epiprob, config.lookahead)

return Dict("inference_results" => inference_results, "epiprob" => epiprob,
"inference_config" => config, "forecast_results" => forecast_results)
score_results = summarise_crps(config, inference_results, forecast_results, epiprob)

return Dict("inference_results" => inference_results,
"epiprob" => epiprob, "inference_config" => config,
"forecast_results" => forecast_results, "score_results" => score_results)
end
8 changes: 5 additions & 3 deletions pipeline/src/infer/generate_inference_results.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ function generate_inference_results(
tspan = make_tspan(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
config = InferenceConfig(
inference_config; case_data = truthdata["y_t"], tspan, epimethod = inference_method)
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
truth_I0 = truthdata["truth_I0"], tspan, epimethod = inference_method)

# produce or load inference results
prfx = _inference_prefix(truthdata, inference_config, pipeline)
Expand Down Expand Up @@ -51,8 +52,9 @@ function generate_inference_results(
truthdata, inference_config, pipeline::EpiAwareExamplePipeline; inference_method)
tspan = make_tspan(
pipeline; T = inference_config["T"], lookback = inference_config["lookback"])
config = InferenceConfig(inference_config; case_data = truthdata["y_t"],
tspan = tspan, epimethod = inference_method)
config = InferenceConfig(
inference_config; case_data = truthdata["y_t"], truth_I_t = truthdata["I_t"],
truth_I0 = truthdata["truth_I0"], tspan = tspan, epimethod = inference_method)

# produce or load inference results
prfx = _inference_prefix(truthdata, inference_config, pipeline)
Expand Down
5 changes: 1 addition & 4 deletions pipeline/src/infer/inference_prefix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@ This is an internal method that generates the part of the prefix for the inferen
results file name from the pipeline.
"""
_prefix_from_pipeline(pipeline::AbstractEpiAwarePipeline) = "observables"
_prefix_from_pipeline(pipeline::SmoothOutbreakPipeline) = "smooth_outbreak"
_prefix_from_pipeline(pipeline::MeasuresOutbreakPipeline) = "measures_outbreak"
_prefix_from_pipeline(pipeline::SmoothEndemicPipeline) = "smooth_endemic"
_prefix_from_pipeline(pipeline::RoughEndemicPipeline) = "rough_endemic"
_prefix_from_pipeline(pipeline::AbstractRtwithoutRenewalPipeline) = pipeline.prefix

"""
This is an internal method that generates the prefix for the inference results file name.
Expand Down
4 changes: 4 additions & 0 deletions pipeline/src/pipeline/pipelinetypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Rt = make_Rt(pipeline) |> Rt -> plot(Rt,
nruns_pthf::Integer = 4
maxiters_pthf::Integer = 100
nchains::Integer = 4
prefix::String = "smooth_outbreak"
end

"""
Expand All @@ -55,6 +56,7 @@ The pipeline type for the Rt pipeline for an outbreak scenario where Rt has
nruns_pthf::Integer = 4
maxiters_pthf::Integer = 100
nchains::Integer = 4
prefix::String = "measures_outbreak"
end

"""
Expand All @@ -67,6 +69,7 @@ The pipeline type for the Rt pipeline for an endemic scenario where Rt changes i
nruns_pthf::Integer = 4
maxiters_pthf::Integer = 100
nchains::Integer = 4
prefix::String = "smooth_endemic"
end

"""
Expand All @@ -79,4 +82,5 @@ The pipeline type for the Rt pipeline for an endemic scenario where Rt changes i
nruns_pthf::Integer = 4
maxiters_pthf::Integer = 100
nchains::Integer = 4
prefix::String = "rough_endemic"
end
8 changes: 4 additions & 4 deletions pipeline/src/plot_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ Plot the true cases and latent infections. This is the default method for plotti
"""
function plot_truth_data(
data, config, pipeline::AbstractEpiAwarePipeline; plotsname = "truth_data")
plt_cases = scatter(
plt_cases = Plots.scatter(
data["y_t"], label = "Cases", xlabel = "Time", ylabel = "Daily cases",
title = "Cases and latent infections", legend = :bottomright)
plot!(plt_cases, data["I_t"], label = "True latent infections")
Plots.plot!(plt_cases, data["I_t"], label = "True latent infections")

if !isdir(plotsdir(plotsname))
mkdir(plotsdir(plotsname))
end

savefig(plt_cases, plotsdir(plotsname, savename(plotsname, config, "png")))
_plotsname = _simulate_prefix(pipeline) * plotsname
savefig(plt_cases, plotsdir(plotsname, savename(_plotsname, config, "png")))
return plt_cases
end

Expand Down
2 changes: 2 additions & 0 deletions pipeline/src/scoring/scoring.jl
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
include("score_parameters.jl")
include("simple_crps.jl")
include("summarise_crps.jl")
35 changes: 35 additions & 0 deletions pipeline/src/scoring/simple_crps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Compute the empirical Continuous Ranked Probability Score (CRPS) for a predictive
distribution defined by the samples `forecasts` with respect to an observed value
`observation`.
The CRPS is defined as the sum of the Mean Absolute Error (MAE) and a pseudo-entropy
term that measures the spread of the forecast distribution.
```math
CRPS = E[|Y - X|] - 0.5 E[|X - X'|]
```
Where `Y` is the observed value, and `X` and `X'` are two random variables drawn from the
forecast distribution.
# Arguments
- `forecasts`: A vector of forecasted values.
- `observation`: The observed value.
# Returns
- `crps`: The computed CRPS.
# Example
```julia
using EpiAwarePipeline
forecasts = randn(100)
observation = randn()
crps = simple_crps(forecasts, observation)
```
"""
function simple_crps(forecasts, observation)
@assert !isempty(forecasts) "Forecasts cannot be empty"
mae = mean(abs, forecasts .- observation)
pseudo_entropy = -0.5 * mean(abs, [x - y for x in forecasts, y in forecasts])
return mae + pseudo_entropy
end
69 changes: 69 additions & 0 deletions pipeline/src/scoring/summarise_crps.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
Summarizes the Continuous Ranked Probability Score (CRPS) for different processes based on the inference results.
# Arguments
- `inference_results`: A dictionary containing the inference results, including
the forecast results and inference configuration.
# Returns
A dictionary containing the summarized CRPS scores for different processes.
"""
function summarise_crps(config, inference_results, forecast_results, epiprob)
ts = config.tspan[1]:min(config.tspan[2] + config.lookahead, length(config.truth_I_t))
epidata = epiprob.epi_model.data

procs_names = (:log_I_t, :rt, :Rt, :I_t, :log_Rt)
scores_log_I_t, scores_rt, scores_Rt, scores_I_t, scores_log_Rt = _process_crps_scores(
procs_names, inference_results, forecast_results, config, ts, epidata)

scores_y_t, scores_log_y_t = _cases_crps_scores(forecast_results, config, ts)

return Dict("ts" => ts, "scores_log_I_t" => scores_log_I_t,
"scores_rt" => scores_rt, "scores_Rt" => scores_Rt,
"scores_I_t" => scores_I_t, "scores_log_Rt" => scores_log_Rt,
"scores_y_t" => scores_y_t, "scores_log_y_t" => scores_log_y_t)
end

"""
Internal method for calculating the CRPS scores for different processes.
"""
function _process_crps_scores(
procs_names, inference_results, forecast_results, config, ts, epidata)
map(procs_names) do process
# Calculate the processes for the truth data
true_Itminusone = ts[1] - 1 == 0 ? config.truth_I0 : config.truth_I_t[ts[1] - 1]
true_proc = calculate_processes(
config.truth_I_t[ts], true_Itminusone, epidata) |>
procs -> getfield(procs, process)
# predictions
gens = forecast_results.generated
log_I0s = inference_results.samples[:init_incidence]
predicted_proc = mapreduce(hcat, gens, log_I0s) do gen, logI0
I0 = exp(logI0)
It = gen.I_t
procs = calculate_processes(It, I0, epidata)
getfield(procs, process)
end
scores = [simple_crps(preds, true_proc[t])
for (t, preds) in enumerate(eachrow(predicted_proc))]
return scores
end
end

"""
Internal method for calculating the CRPS scores for observed cases and log(cases),
including the forecast score for future cases.
"""
function _cases_crps_scores(forecast_results, config, ts; jitter = 1e-6)
true_y_t = config.case_data[ts]
gens = forecast_results.generated
predicted_y_t = mapreduce(hcat, gens) do gen
gen.generated_y_t
end
scores_y_t = [simple_crps(preds, true_y_t[t])
for (t, preds) in enumerate(eachrow(predicted_y_t))]
scores_log_y_t = [simple_crps(log.(preds .+ jitter), log(true_y_t[t] + jitter))
for (t, preds) in enumerate(eachrow(predicted_y_t))]
return scores_y_t, scores_log_y_t
end
11 changes: 3 additions & 8 deletions pipeline/src/simulate/simulate_prefix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
Internal method for setting the prefix for the truth data file name.
"""
_simulate_prefix(pipeline::AbstractEpiAwarePipeline) = "truth_data"

_simulate_prefix(pipeline::SmoothOutbreakPipeline) = "truth_data_smooth_outbreak"

_simulate_prefix(pipeline::MeasuresOutbreakPipeline) = "truth_data_measures_outbreak"

_simulate_prefix(pipeline::SmoothEndemicPipeline) = "truth_data_smooth_endemic"

_simulate_prefix(pipeline::RoughEndemicPipeline) = "truth_data_rough_endemic"
function _simulate_prefix(pipeline::AbstractRtwithoutRenewalPipeline)
"truth_data_" * pipeline.prefix
end
13 changes: 6 additions & 7 deletions pipeline/src/utils/calculate_processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ estimate of `rt`.
"""
function _infection_seeding(
I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline; jitter = 1e-6)
I_t, I0, data::EpiData; jitter = 1e-6)
n = length(data.gen_int)
init_rt = _calc_rt(I_t[1:2] .+ jitter, I0 + jitter) |> x -> x[1]
[(I0 + jitter) * exp(-init_rt * (n - i)) for i in 1:n]
Expand All @@ -41,12 +41,11 @@ growth from the initial infections `I0` and the exponential growth rate `init_rt
- `I0`: Initial infections at time zero.
- `init_rt`: Initial exponential growth rate.
- `data::EpiData`: An instance of the `EpiData` type containing generation interval data.
- `pipeline::AbstractEpiAwarePipeline`: An instance of the `AbstractEpiAwarePipeline` type.
"""
function _calc_Rt(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline; jitter = 1e-6)
function _calc_Rt(I_t, I0, data::EpiData; jitter = 1e-6)
@assert I0 + jitter>0 "Initial infections must be positive definite."

aug_I_t = vcat(_infection_seeding(I_t .+ jitter, I0 + jitter, data, pipeline), I_t)
aug_I_t = vcat(_infection_seeding(I_t .+ jitter, I0 + jitter, data), I_t)

Rt = expected_Rt(data, aug_I_t)

Expand All @@ -69,9 +68,9 @@ from the first 7 time steps of `rt`.
A named tuple containing the calculated values for `log_I_t`, `rt`, and `Rt`.
"""
function calculate_processes(I_t, I0, data::EpiData, pipeline::AbstractEpiAwarePipeline)
function calculate_processes(I_t, I0, data::EpiData)
log_I_t = _calc_log_infections(I_t)
rt = _calc_rt(I_t, I0)
Rt = _calc_Rt(I_t, I0, data, pipeline)
return (; log_I_t, rt, Rt)
Rt = _calc_Rt(I_t, I0, data)
return (; log_I_t, rt, Rt, I_t, log_Rt = log.(Rt))
end
Loading

0 comments on commit 2b816a8

Please sign in to comment.