Skip to content

Commit

Permalink
Fix: figure 1 script (#555)
Browse files Browse the repository at this point in the history
* fix figure 1 script

* fix and refactor figure 1
  • Loading branch information
SamuelBrand1 authored Dec 16, 2024
1 parent ab73d9e commit ccbbed3
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 305 deletions.
55 changes: 26 additions & 29 deletions pipeline/scripts/create_figure1.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,14 @@
## Script to make figure 1 and alternate latent models for SI
using Pkg
Pkg.activate(joinpath(@__DIR__(), ".."))
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta,
Statistics, Distributions, CSV, CairoMakie

using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta,
Statistics, Distributions, CSV

##
pipelines = [
SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(),
SmoothEndemicPipeline(), RoughEndemicPipeline()]
## Define scenarios and targets
scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]
targets = ["log_I_t", "rt", "Rt"]
gi_means = [2.0, 10.0, 20.0]

## load some data and create a dataframe for the plot
truth_data_files = readdir(datadir("truth_data")) |>
strs -> filter(s -> occursin("jld2", s), strs)
analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame
truth_df = mapreduce(vcat, truth_data_files) do filename
D = load(joinpath(datadir("truth_data"), filename))
make_truthdata_dataframe(filename, D, pipelines)
end
truth_data_df = CSV.File(plotsdir("plotting_data/truthdata.csv")) |> DataFrame
prediction_df = CSV.File(plotsdir("plotting_data/predictions.csv")) |> DataFrame

# Define scenario titles and reference times for figure 1
scenario_dict = Dict(
Expand All @@ -28,21 +19,27 @@ scenario_dict = Dict(
)

target_dict = Dict(
"log_I_t" => (title = "log(Incidence)", ylims = (3.5, 6)),
"rt" => (title = "Exp. growth rate", ylims = (-0.1, 0.1)),
"Rt" => (title = "Reproductive number", ylims = (-0.1, 3))
"log_I_t" => (title = "log(Incidence)",),
"rt" => (title = "Exp. growth rate",),
"Rt" => (title = "Reproductive number",)
)

latent_model_dict = Dict(
"wkly_rw" => (title = "Random walk",),
"wkly_ar" => (title = "AR(1)",),
"wkly_diff_ar" => (title = "Diff. AR(1)",)
"rw" => (title = "Random walk",),
"ar" => (title = "AR(1)",),
"diff_ar" => (title = "Diff. AR(1)",)
)

## `wkly_ar` is the default latent model which we show as figure 1, others are for SI

_ = map(latent_model_dict |> keys |> collect) do latent_model
fig = figureone(
truth_df, analysis_df, latent_model, scenario_dict, target_dict, latent_model_dict)
save(plotsdir("figure1_$(latent_model).png"), fig)
## `ar` is the default latent model which we show as figure 1, others are for SI

_ = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model
map(Iterators.product(gi_means, gi_means)) do (true_gi_choice, used_gi_choice)
fig = figureone(
prediction_df, truth_data_df, scenarios, targets; scenario_dict, target_dict,
latent_model_dict, latent_model, true_gi_choice, used_gi_choice)
# save(plotsdir("figure1_$(latent_model).png"), fig)
save(
plotsdir("figure1_$(latent_model)_trueGI_$(true_gi_choice)_usedGI_$(used_gi_choice).png"),
fig)
end
end
2 changes: 1 addition & 1 deletion pipeline/scripts/create_prediction_dataframe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ dfs = mapreduce(vcat, scenarios) do scenario
mapreduce(vcat, files) do filename
output = load(joinpath(datadir("epiaware_observables"), scenario, filename))
try
make_prediction_dataframe_from_output(output, true_gi_mean)
make_prediction_dataframe_from_output(output, true_gi_mean, scenario)
catch e
@warn "Error in $filename"
push!(failed_configs, output["inference_config"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ A dataframe containing the prediction results.
"""
function make_prediction_dataframe_from_output(
output, true_mean_gi; qs = [0.025, 0.25, 0.5, 0.75, 0.975],
output, true_mean_gi, scenario; qs = [0.025, 0.25, 0.5, 0.75, 0.975],
transformation = oneexpy)
#Unpack the output
inference_config = output["inference_config"]
forecasts = output["forecast_results"]
#Get the scenario, IGP model, latent model and true mean GI
igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end]
scenario = inference_config["scenario"]
latent_model = inference_config["latent_model"]
used_gi_mean = inference_config["gi_mean"]
used_gi_std = inference_config["gi_std"]
Expand Down
Loading

0 comments on commit ccbbed3

Please sign in to comment.