From 12a21c69e9e7b8729986f176c3455d8dd5b7a077 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Fri, 13 Dec 2024 11:35:30 +0000 Subject: [PATCH] Issue 549: Extra post-processing steps (#551) * Update changelog.md * Update make_model_priors.jl * Make `define_epiprob` more modular _make_epidata can also be reused elsewhere * restructure analysis dataframe func * fix make_prediction_dataframe_from_output And add simple unit test with committed test data * Revert "fix make_prediction_dataframe_from_output" This reverts commit 7e6238bc275a79381b318204527748754932657d. * Reapply "fix make_prediction_dataframe_from_output" This reverts commit 14d29b42608132bc8603daf9f16926251677da45. * script to generate prediction dataframes * create_prediction_df refactor + fix * update to analyses failures * Update create_prediction_dataframe.jl * refactor make_truthdata_dataframe And add unit test * create script for generating postprocessed truth data And bundle into single script for DRY --- .../create_postprocessing_dataframes.jl | 14 ++++++++++++++ .../scripts/create_prediction_dataframe.jl | 6 ------ pipeline/scripts/create_truth_dataframe.jl | 11 +++++++++++ .../src/analysis/make_truthdata_dataframe.jl | 15 +++++---------- .../test/analysis/make_truthdata_dataframe.jl | 18 ++++++++++++++++++ pipeline/test/analysis/test_analysis.jl | 1 + 6 files changed, 49 insertions(+), 16 deletions(-) create mode 100644 pipeline/scripts/create_postprocessing_dataframes.jl create mode 100644 pipeline/scripts/create_truth_dataframe.jl create mode 100644 pipeline/test/analysis/make_truthdata_dataframe.jl diff --git a/pipeline/scripts/create_postprocessing_dataframes.jl b/pipeline/scripts/create_postprocessing_dataframes.jl new file mode 100644 index 000000000..4b3ea1568 --- /dev/null +++ b/pipeline/scripts/create_postprocessing_dataframes.jl @@ -0,0 +1,14 @@ +using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV + +## Define scenarios +scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"] + +if !isfile(plotsdir("plotting_data/predictions.csv")) + @info "Prediction dataframe does not exist, generating now" + include("create_prediction_dataframe.jl") +end + +if !isfile(plotsdir("plotting_data/truthdata.csv")) + @info "Truth dataframe does not exist, generating now" + include("create_truth_dataframe.jl") +end diff --git a/pipeline/scripts/create_prediction_dataframe.jl b/pipeline/scripts/create_prediction_dataframe.jl index 9d02c6de1..fb70a2458 100644 --- a/pipeline/scripts/create_prediction_dataframe.jl +++ b/pipeline/scripts/create_prediction_dataframe.jl @@ -1,9 +1,3 @@ -using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta, - Statistics, Distributions, DrWatson, CSV - -## Define scenarios -scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"] - ## Define true GI means true_gi_means = [2.0, 10.0, 20.0] diff --git a/pipeline/scripts/create_truth_dataframe.jl b/pipeline/scripts/create_truth_dataframe.jl new file mode 100644 index 000000000..2dc7125bc --- /dev/null +++ b/pipeline/scripts/create_truth_dataframe.jl @@ -0,0 +1,11 @@ +truth_df = mapreduce(vcat, scenarios) do scenario + truth_data_files = readdir(datadir("truth_data")) |> + strs -> filter(s -> occursin("jld2", s), strs) |> + strs -> filter(s -> occursin(scenario, s), strs) + mapreduce(vcat, truth_data_files) do filename + D = load(joinpath(datadir("truth_data"), filename)) + make_truthdata_dataframe(D, scenario) + end +end + +CSV.write(plotsdir("plotting_data/truthdata.csv"), truth_df) diff --git a/pipeline/src/analysis/make_truthdata_dataframe.jl b/pipeline/src/analysis/make_truthdata_dataframe.jl index 1adedeb9b..66159687a 100644 --- a/pipeline/src/analysis/make_truthdata_dataframe.jl +++ b/pipeline/src/analysis/make_truthdata_dataframe.jl @@ -1,25 +1,20 @@ - """ - make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) - Create a DataFrame containing truth data for analysis. # Arguments -- `filename::String`: The name of the file. -- `truth_data::Dict`: A dictionary containing truth data. -- `pipelines::Array`: An array of pipelines. -- `I_0::Float64`: Initial value for I_t (default: 100.0). +- `truth_data`: A dictionary containing truth data. +- `scenario`: Name of the truth data scenario. # Returns -- `df::DataFrame`: A DataFrame containing the truth data. +- `df::DataFrame`: A DataFrame containing the summarised truth data. """ -function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0) +function make_truthdata_dataframe(truth_data::Dict, scenario::String) I_t = truth_data["I_t"] + I_0 = truth_data["truth_I0"] true_mean_gi = truth_data["truth_gi_mean"] log_It = _calc_log_infections(I_t) rt = _calc_rt(I_t, I_0) - scenario = _get_scenario_from_filename(filename, pipelines) truth_procs = (; log_I_t = log_It, rt, Rt = truth_data["truth_process"]) df = mapreduce(vcat, keys(truth_procs)) do target diff --git a/pipeline/test/analysis/make_truthdata_dataframe.jl b/pipeline/test/analysis/make_truthdata_dataframe.jl new file mode 100644 index 000000000..866427ce0 --- /dev/null +++ b/pipeline/test/analysis/make_truthdata_dataframe.jl @@ -0,0 +1,18 @@ +@testset "make_truthdata_dataframe tests" begin + truth_data = Dict( + "I_t" => [10, 20, 30], + "truth_I0" => 5, + "truth_gi_mean" => 2.5, + "truth_process" => [1.0, 1.5, 2.0] + ) + scenario = "test_scenario" + + df = make_truthdata_dataframe(truth_data, scenario) + + @test typeof(df) == DataFrame + @test size(df, 1) == 9 # 3 targets * 3 time points + @test all(df.Scenario .== scenario) + @test all(df.True_GI_Mean .== truth_data["truth_gi_mean"]) + @test all(df.Target .== + ["log_I_t", "log_I_t", "log_I_t", "rt", "rt", "rt", "Rt", "Rt", "Rt"]) +end diff --git a/pipeline/test/analysis/test_analysis.jl b/pipeline/test/analysis/test_analysis.jl index 58076ec5d..1e2608aea 100644 --- a/pipeline/test/analysis/test_analysis.jl +++ b/pipeline/test/analysis/test_analysis.jl @@ -1 +1,2 @@ include("make_prediction_dataframe_from_output.jl") +include("make_truthdata_dataframe.jl")