Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into issue509
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Dec 13, 2024
2 parents 3f5f4e5 + 12a21c6 commit 0d56246
Show file tree
Hide file tree
Showing 6 changed files with 49 additions and 16 deletions.
14 changes: 14 additions & 0 deletions pipeline/scripts/create_postprocessing_dataframes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using EpiAwarePipeline, EpiAware, JLD2, DrWatson, DataFramesMeta, CSV

## Define scenarios
scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]

if !isfile(plotsdir("plotting_data/predictions.csv"))
@info "Prediction dataframe does not exist, generating now"
include("create_prediction_dataframe.jl")
end

if !isfile(plotsdir("plotting_data/truthdata.csv"))
@info "Truth dataframe does not exist, generating now"
include("create_truth_dataframe.jl")
end
6 changes: 0 additions & 6 deletions pipeline/scripts/create_prediction_dataframe.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,3 @@
using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta,
Statistics, Distributions, DrWatson, CSV

## Define scenarios
scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]

## Define true GI means
true_gi_means = [2.0, 10.0, 20.0]

Expand Down
11 changes: 11 additions & 0 deletions pipeline/scripts/create_truth_dataframe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
truth_df = mapreduce(vcat, scenarios) do scenario
truth_data_files = readdir(datadir("truth_data")) |>
strs -> filter(s -> occursin("jld2", s), strs) |>
strs -> filter(s -> occursin(scenario, s), strs)
mapreduce(vcat, truth_data_files) do filename
D = load(joinpath(datadir("truth_data"), filename))
make_truthdata_dataframe(D, scenario)
end
end

CSV.write(plotsdir("plotting_data/truthdata.csv"), truth_df)
15 changes: 5 additions & 10 deletions pipeline/src/analysis/make_truthdata_dataframe.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@

"""
make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0)
Create a DataFrame containing truth data for analysis.
# Arguments
- `filename::String`: The name of the file.
- `truth_data::Dict`: A dictionary containing truth data.
- `pipelines::Array`: An array of pipelines.
- `I_0::Float64`: Initial value for I_t (default: 100.0).
- `truth_data`: A dictionary containing truth data.
- `scenario`: Name of the truth data scenario.
# Returns
- `df::DataFrame`: A DataFrame containing the truth data.
- `df::DataFrame`: A DataFrame containing the summarised truth data.
"""
function make_truthdata_dataframe(filename, truth_data, pipelines; I_0 = 100.0)
function make_truthdata_dataframe(truth_data::Dict, scenario::String)
I_t = truth_data["I_t"]
I_0 = truth_data["truth_I0"]
true_mean_gi = truth_data["truth_gi_mean"]
log_It = _calc_log_infections(I_t)
rt = _calc_rt(I_t, I_0)
scenario = _get_scenario_from_filename(filename, pipelines)
truth_procs = (; log_I_t = log_It, rt, Rt = truth_data["truth_process"])

df = mapreduce(vcat, keys(truth_procs)) do target
Expand Down
18 changes: 18 additions & 0 deletions pipeline/test/analysis/make_truthdata_dataframe.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testset "make_truthdata_dataframe tests" begin
truth_data = Dict(
"I_t" => [10, 20, 30],
"truth_I0" => 5,
"truth_gi_mean" => 2.5,
"truth_process" => [1.0, 1.5, 2.0]
)
scenario = "test_scenario"

df = make_truthdata_dataframe(truth_data, scenario)

@test typeof(df) == DataFrame
@test size(df, 1) == 9 # 3 targets * 3 time points
@test all(df.Scenario .== scenario)
@test all(df.True_GI_Mean .== truth_data["truth_gi_mean"])
@test all(df.Target .==
["log_I_t", "log_I_t", "log_I_t", "rt", "rt", "rt", "Rt", "Rt", "Rt"])
end
1 change: 1 addition & 0 deletions pipeline/test/analysis/test_analysis.jl
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
include("make_prediction_dataframe_from_output.jl")
include("make_truthdata_dataframe.jl")

0 comments on commit 0d56246

Please sign in to comment.