Split truth-data post-processing into its own script and simplify
make_truthdata_dataframe to take the truth-data dictionary and scenario name.

Summary of the hunks below:
  * scripts/create_postprocessing_dataframes.jl (new): loads the packages,
    defines `scenarios`, and includes the prediction / truth scripts only
    when their CSV outputs are not already present under plotsdir().
  * scripts/create_prediction_dataframe.jl: drops its own `using` and
    `scenarios` lines, which the new driver script now provides.
  * scripts/create_truth_dataframe.jl (new): builds `truth_df` from the
    jld2 files in datadir("truth_data") and writes truthdata.csv.
  * src/analysis/make_truthdata_dataframe.jl: new two-argument signature;
    I_0 is read from truth_data["truth_I0"] instead of a keyword argument.
  * test/analysis: adds a test set for make_truthdata_dataframe and
    includes it from test_analysis.jl.

NOTE(review): this patch was recovered from a copy whose line breaks and
leading indentation had been collapsed. Indentation is reconstructed at four
spaces, and one blank context line (before the closing docstring quotes in
make_truthdata_dataframe.jl) is placed to satisfy the `-1,25 +1,20` hunk
counts. Confirm against the target tree; apply with
`git apply --ignore-whitespace` if context whitespace does not match.

diff --git a/pipeline/scripts/create_postprocessing_dataframes.jl b/pipeline/scripts/create_postprocessing_dataframes.jl
new file mode 100644
index 000000000..4b3ea1568
--- /dev/null
+++ b/pipeline/scripts/create_postprocessing_dataframes.jl
@@ -0,0 +1,14 @@
+using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV
+
+## Define scenarios
+scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]
+
+if !isfile(plotsdir("plotting_data/predictions.csv"))
+    @info "Prediction dataframe does not exist, generating now"
+    include("create_prediction_dataframe.jl")
+end
+
+if !isfile(plotsdir("plotting_data/truthdata.csv"))
+    @info "Truth dataframe does not exist, generating now"
+    include("create_truth_dataframe.jl")
+end
diff --git a/pipeline/scripts/create_prediction_dataframe.jl b/pipeline/scripts/create_prediction_dataframe.jl
index 9d02c6de1..fb70a2458 100644
--- a/pipeline/scripts/create_prediction_dataframe.jl
+++ b/pipeline/scripts/create_prediction_dataframe.jl
@@ -1,9 +1,3 @@
-using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta,
-    Statistics, Distributions, DrWatson, CSV
-
-## Define scenarios
-scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]
-
 ## Define true GI means
 true_gi_means = [2.0, 10.0, 20.0]
 
diff --git a/pipeline/scripts/create_truth_dataframe.jl b/pipeline/scripts/create_truth_dataframe.jl
new file mode 100644
index 000000000..2dc7125bc
--- /dev/null
+++ b/pipeline/scripts/create_truth_dataframe.jl
@@ -0,0 +1,11 @@
+truth_df = mapreduce(vcat, scenarios) do scenario
+    truth_data_files = readdir(datadir("truth_data")) |>
+                       strs -> filter(s -> occursin("jld2", s), strs) |>
+                               strs -> filter(s -> occursin(scenario, s), strs)
+    mapreduce(vcat, truth_data_files) do filename
+        D = load(joinpath(datadir("truth_data"), filename))
+        make_truthdata_dataframe(D, scenario)
+    end
+end
+
+CSV.write(plotsdir("plotting_data/truthdata.csv"), truth_df)
diff --git a/pipeline/src/analysis/make_truthdata_dataframe.jl b/pipeline/src/analysis/make_truthdata_dataframe.jl
index 1adedeb9b..66159687a 100644
--- a/pipeline/src/analysis/make_truthdata_dataframe.jl
+++ b/pipeline/src/analysis/make_truthdata_dataframe.jl
@@ -1,25 +1,20 @@
-
 """
-    make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0)
-
 Create a DataFrame containing truth data for analysis.
 
 # Arguments
-- `filename::String`: The name of the file.
-- `truth_data::Dict`: A dictionary containing truth data.
-- `pipelines::Array`: An array of pipelines.
-- `I_0::Float64`: Initial value for I_t (default: 100.0).
+- `truth_data`: A dictionary containing truth data.
+- `scenario`: Name of the truth data scenario.
 
 # Returns
-- `df::DataFrame`: A DataFrame containing the truth data.
+- `df::DataFrame`: A DataFrame containing the summarised truth data.
 
 """
-function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0)
+function make_truthdata_dataframe(truth_data::Dict, scenario::String)
     I_t = truth_data["I_t"]
+    I_0 = truth_data["truth_I0"]
     true_mean_gi = truth_data["truth_gi_mean"]
     log_It = _calc_log_infections(I_t)
     rt = _calc_rt(I_t, I_0)
-    scenario = _get_scenario_from_filename(filename, pipelines)
     truth_procs = (; log_I_t = log_It, rt, Rt = truth_data["truth_process"])
 
     df = mapreduce(vcat, keys(truth_procs)) do target
diff --git a/pipeline/test/analysis/make_truthdata_dataframe.jl b/pipeline/test/analysis/make_truthdata_dataframe.jl
new file mode 100644
index 000000000..866427ce0
--- /dev/null
+++ b/pipeline/test/analysis/make_truthdata_dataframe.jl
@@ -0,0 +1,18 @@
+@testset "make_truthdata_dataframe tests" begin
+    truth_data = Dict(
+        "I_t" => [10, 20, 30],
+        "truth_I0" => 5,
+        "truth_gi_mean" => 2.5,
+        "truth_process" => [1.0, 1.5, 2.0]
+    )
+    scenario = "test_scenario"
+
+    df = make_truthdata_dataframe(truth_data, scenario)
+
+    @test typeof(df) == DataFrame
+    @test size(df, 1) == 9 # 3 targets * 3 time points
+    @test all(df.Scenario .== scenario)
+    @test all(df.True_GI_Mean .== truth_data["truth_gi_mean"])
+    @test all(df.Target .==
+              ["log_I_t", "log_I_t", "log_I_t", "rt", "rt", "rt", "Rt", "Rt", "Rt"])
+end
diff --git a/pipeline/test/analysis/test_analysis.jl b/pipeline/test/analysis/test_analysis.jl
index 58076ec5d..1e2608aea 100644
--- a/pipeline/test/analysis/test_analysis.jl
+++ b/pipeline/test/analysis/test_analysis.jl
@@ -1 +1,2 @@
 include("make_prediction_dataframe_from_output.jl")
+include("make_truthdata_dataframe.jl")