From ccbbed3a5de81c4dab3a1192722635fe7ce31b03 Mon Sep 17 00:00:00 2001 From: Samuel Brand <48288458+SamuelBrand1@users.noreply.github.com> Date: Mon, 16 Dec 2024 11:00:47 +0000 Subject: [PATCH] Fix: figure 1 script (#555) * fix figure 1 script * fix and refactor figure 1 --- pipeline/scripts/create_figure1.jl | 55 ++- .../scripts/create_prediction_dataframe.jl | 2 +- .../make_prediction_dataframe_from_output.jl | 3 +- pipeline/src/plotting/figureone.jl | 382 +++++------------- 4 files changed, 137 insertions(+), 305 deletions(-) diff --git a/pipeline/scripts/create_figure1.jl b/pipeline/scripts/create_figure1.jl index a4dd6ed12..46fbb39f4 100644 --- a/pipeline/scripts/create_figure1.jl +++ b/pipeline/scripts/create_figure1.jl @@ -1,23 +1,14 @@ -## Script to make figure 1 and alternate latent models for SI -using Pkg -Pkg.activate(joinpath(@__DIR__(), "..")) +using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, DataFramesMeta, + Statistics, Distributions, CSV, CairoMakie -using EpiAwarePipeline, EpiAware, AlgebraOfGraphics, JLD2, DrWatson, Plots, DataFramesMeta, - Statistics, Distributions, CSV - -## -pipelines = [ - SmoothOutbreakPipeline(), MeasuresOutbreakPipeline(), - SmoothEndemicPipeline(), RoughEndemicPipeline()] +## Define scenarios and targets +scenarios = ["measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"] +targets = ["log_I_t", "rt", "Rt"] +gi_means = [2.0, 10.0, 20.0] ## load some data and create a dataframe for the plot -truth_data_files = readdir(datadir("truth_data")) |> - strs -> filter(s -> occursin("jld2", s), strs) -analysis_df = CSV.File(plotsdir("analysis_df.csv")) |> DataFrame -truth_df = mapreduce(vcat, truth_data_files) do filename - D = load(joinpath(datadir("truth_data"), filename)) - make_truthdata_dataframe(filename, D, pipelines) -end +truth_data_df = CSV.File(plotsdir("plotting_data/truthdata.csv")) |> DataFrame +prediction_df = CSV.File(plotsdir("plotting_data/predictions.csv")) |> DataFrame # Define scenario titles and reference times for figure 1 scenario_dict = Dict( @@ -28,21 +19,27 @@ scenario_dict = Dict( ) target_dict = Dict( - "log_I_t" => (title = "log(Incidence)", ylims = (3.5, 6)), - "rt" => (title = "Exp. growth rate", ylims = (-0.1, 0.1)), - "Rt" => (title = "Reproductive number", ylims = (-0.1, 3)) + "log_I_t" => (title = "log(Incidence)",), + "rt" => (title = "Exp. growth rate",), + "Rt" => (title = "Reproductive number",) ) latent_model_dict = Dict( - "wkly_rw" => (title = "Random walk",), - "wkly_ar" => (title = "AR(1)",), - "wkly_diff_ar" => (title = "Diff. AR(1)",) + "rw" => (title = "Random walk",), + "ar" => (title = "AR(1)",), + "diff_ar" => (title = "Diff. AR(1)",) ) -## `wkly_ar` is the default latent model which we show as figure 1, others are for SI - -_ = map(latent_model_dict |> keys |> collect) do latent_model - fig = figureone( - truth_df, analysis_df, latent_model, scenario_dict, target_dict, latent_model_dict) - save(plotsdir("figure1_$(latent_model).png"), fig) +## `ar` is the default latent model which we show as figure 1, others are for SI + +_ = mapreduce(vcat, latent_model_dict |> keys |> collect) do latent_model + map(Iterators.product(gi_means, gi_means)) do (true_gi_choice, used_gi_choice) + fig = figureone( + prediction_df, truth_data_df, scenarios, targets; scenario_dict, target_dict, + latent_model_dict, latent_model, true_gi_choice, used_gi_choice) + # save(plotsdir("figure1_$(latent_model).png"), fig) + save( + plotsdir("figure1_$(latent_model)_trueGI_$(true_gi_choice)_usedGI_$(used_gi_choice).png"), + fig) + end end diff --git a/pipeline/scripts/create_prediction_dataframe.jl b/pipeline/scripts/create_prediction_dataframe.jl index fb70a2458..c0e61409d 100644 --- a/pipeline/scripts/create_prediction_dataframe.jl +++ b/pipeline/scripts/create_prediction_dataframe.jl @@ -14,7 +14,7 @@ dfs = mapreduce(vcat, scenarios) do scenario mapreduce(vcat, files) do filename output = load(joinpath(datadir("epiaware_observables"), scenario, filename)) try - make_prediction_dataframe_from_output(output, true_gi_mean) + make_prediction_dataframe_from_output(output, true_gi_mean, scenario) catch e @warn "Error in $filename" push!(failed_configs, output["inference_config"]) diff --git a/pipeline/src/analysis/make_prediction_dataframe_from_output.jl b/pipeline/src/analysis/make_prediction_dataframe_from_output.jl index c1980a7cd..cc68012ba 100644 --- a/pipeline/src/analysis/make_prediction_dataframe_from_output.jl +++ b/pipeline/src/analysis/make_prediction_dataframe_from_output.jl @@ -14,14 +14,13 @@ A dataframe containing the prediction results. """ function make_prediction_dataframe_from_output( - output, true_mean_gi; qs = [0.025, 0.25, 0.5, 0.75, 0.975], + output, true_mean_gi, scenario; qs = [0.025, 0.25, 0.5, 0.75, 0.975], transformation = oneexpy) #Unpack the output inference_config = output["inference_config"] forecasts = output["forecast_results"] #Get the scenario, IGP model, latent model and true mean GI igp_model = inference_config["igp"] |> igp_name -> split(igp_name, ".")[end] - scenario = inference_config["scenario"] latent_model = inference_config["latent_model"] used_gi_mean = inference_config["gi_mean"] used_gi_std = inference_config["gi_std"] diff --git a/pipeline/src/plotting/figureone.jl b/pipeline/src/plotting/figureone.jl index 5e05dbc6a..1a040bd15 100644 --- a/pipeline/src/plotting/figureone.jl +++ b/pipeline/src/plotting/figureone.jl @@ -1,298 +1,134 @@ """ -Internal method for creating a figure of model inference for a specific scenario - using the given analysis data. +Plot predictions on the given axis (`ax`) based on the provided parameters. # Arguments -- `analysis_df`: The analysis data frame. -- `scenario`: The scenario to plot. -- `reference_time`: The reference time. -- `true_gi_choice`: The true GI choice. -- `used_gi_choice`: The used GI choice. -- `lower_sym`: The symbol for the lower quantile (default is `:q_025`). -- `upper_sym`: The symbol for the upper quantile (default is `:q_975`). - -# Returns -- `plt_model`: The plot object. - -""" -function _figure_one_scenario(analysis_df, scenario; reference_time, true_gi_choice, - used_gi_choice, lower_sym = :q_025, upper_sym = :q_975) - model_plotting_data = analysis_df |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> - df -> @subset(df, - :Reference_Time.==reference_time) |> - df -> @subset(df, :Scenario.==scenario) |> - data - - plt_model = model_plotting_data * - mapping(:target_times => "T", :q_5 => "Process values", - col = :Target, row = :IGP_Model => "IGP model", - color = :Latent_Model => "Latent model") * - mapping(lower = lower_sym, upper = upper_sym) * visual(LinesFill) - - return plt_model -end - -""" -Internal method that generates a plot of the truth data for a specific scenario. - -## Arguments -- `truth_df`: The truth data DataFrame. -- `scenario`: The scenario for which the truth data should be plotted. -- `true_gi_choice`: The choice of true GI mean. - -## Returns -- `plt_truth`: The plot of the truth data. - -""" -function _figure_scenario_truth_data(truth_df, scenario; true_gi_choice) - truth_plotting_data = truth_df |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :Scenario.==scenario) |> data - plt_truth = truth_plotting_data * - mapping(:target_times => "T", :target_values => "values", - col = :Target, color = :Latent_Model => "Latent Model") * - visual(Lines) - return plt_truth -end +- `ax`: The axis on which to plot the predictions. +- `predictions`: DataFrame containing the prediction data. +- `scenario`: The scenario to filter the predictions. +- `target`: The target to filter the predictions. +- `reference_time`: The reference time to filter the predictions. +- `latent_model`: The latent model to filter the predictions. +- `igps`: A list of IGP models to plot. Default is `["DirectInfections", "ExpGrowthRate", "Renewal"]`. +- `true_gi_choice`: The true generation interval mean to filter the predictions. Default is `2.0`. +- `used_gi_choice`: The used generation interval mean to filter the predictions. Default is `2.0`. +- `colors`: A list of colors for each IGP model. Default is `[:red, :blue, :green]`. +- `iqr_alpha`: The alpha value for the interquartile range bands. Default is `0.3`. + +# Description +This function filters the `predictions` DataFrame based on the provided parameters and plots + the predictions on the given axis (`ax`). It plots the median prediction line and two + bands representing the interquartile range (IQR) and the 95% prediction interval for + each IGP model specified in `igps`. """ -Generate a version figure 1 showing the analysis and truth data for different scenarios _and_ -different latent process models. - -## Arguments -- `truth_df`: DataFrame containing the truth data. -- `analysis_df`: DataFrame containing the analysis data. -- `scenario_dict`: Dictionary containing information about the scenarios. - -## Keyword Arguments -- `fig_kws`: Keyword arguments for the Figure object. Default is `(; size = (1000, 2000))`. -- `true_gi_choice`: Value for the true generation interval choice. Default is `10.0`. -- `used_gi_choice`: Value for the used generation interval choice. Default is `10.0`. -- `legend_title`: Title for the legend. Default is `"Process type"`. - -## Returns -- `fig`: Figure object containing the generated figure. - -""" -function figureone_with_latent_model( - truth_df, analysis_df, scenario_dict; fig_kws = (; size = (1000, 2000)), - true_gi_choice = 10.0, used_gi_choice = 10.0, legend_title = "Process type") - # Perform checks on the dataframes - _dataframe_checks(truth_df, analysis_df, scenario_dict) - # Treat the truth data as a Latent model option - truth_df[!, "Latent_Model"] .= "Truth data" - - scenarios = analysis_df.Scenario |> unique - plt_truth_vect = map(scenarios) do scenario - _figure_scenario_truth_data(truth_df, scenario; true_gi_choice) - end - plt_analysis_vect = map(scenarios) do scenario - _figure_one_scenario( - analysis_df, scenario; reference_time = scenario_dict[scenario].T, - true_gi_choice, used_gi_choice) - end - - fig = Figure(; fig_kws...) - leg = nothing - for (i, scenario) in enumerate(scenarios) - sf = fig[i, :] - ag = draw!( - sf, plt_analysis_vect[i] + plt_truth_vect[i], facet = (; linkyaxes = :none)) - leg = AlgebraOfGraphics.compute_legend(ag) - Label(sf[0, :], scenario_dict[scenario].title, fontsize = 24, font = :bold) +function _plot_predictions!( + ax, predictions, scenario, target, reference_time, latent_model; + igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], + true_gi_choice = 2.0, used_gi_choice = 2.0, colors = [:red, :blue, :green], + iqr_alpha = 0.3) + pred = predictions |> + df -> @subset(df, :Latent_Model.==latent_model) |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> + df -> @subset(df, :Reference_Time.==reference_time) |> + df -> @subset(df, :Scenario.==scenario) |> + df -> @subset(df, :Target.==target) + for (c, igp) in zip(colors, igps) + x = pred[pred.IGP_Model .== igp, "target_times"] + y = pred[pred.IGP_Model .== igp, "q_5"] + upr1 = pred[pred.IGP_Model .== igp, "q_75"] + upr2 = pred[pred.IGP_Model .== igp, "q_975"] + lwr1 = pred[pred.IGP_Model .== igp, "q_25"] + lwr2 = pred[pred.IGP_Model .== igp, "q_025"] + if length(x) > 0 + lines!(ax, x, y, color = c, label = igp, linewidth = 3) + band!(ax, x, lwr1, upr1, color = (c, iqr_alpha)) + band!(ax, x, lwr2, upr2, color = (c, iqr_alpha / 2)) + end end - - Label(fig[:, 0], "Process values", fontsize = 28, font = :bold, rotation = pi / 2) - Label(fig[:, 2], "Infection generating process", - fontsize = 24, font = :bold, rotation = -pi / 2) - _leg = (leg[1], leg[2], [legend_title]) - Legend(fig[:, 3], _leg...) - - return fig + return nothing end """ -Internal method for creating a model panel plot for Figure One. - -This function takes in various parameters to filter the `analysis_df` DataFrame and create a model panel plot for Figure One. -The filtered DataFrame is used to generate the plot using the `model_plotting_data` variable. -The plot includes process values, color-coded by the infection generating process, -and credible intervals defined by `lower_sym` and `upper_sym`. +Plot the truth data on the given axis. -## Arguments -- `analysis_df`: The DataFrame containing the analysis data. -- `scenario`: The scenario to filter the DataFrame. -- `target`: The target to filter the DataFrame. -- `latentmodel`: The latent model to filter the DataFrame. -- `reference_time`: The reference time to filter the DataFrame. -- `true_gi_choice`: The true GI mean value to filter the DataFrame. -- `used_gi_choice`: The used GI mean value to filter the DataFrame. -- `lower_sym`: The symbol representing the lower confidence interval (default: `:q_025`). -- `upper_sym`: The symbol representing the upper confidence interval (default: `:q_975`). - -## Returns -- `plt_model`: The model panel plot. - -""" -function _figure_one_model_panel( - analysis_df, scenario, target, latentmodel; reference_time, true_gi_choice, - used_gi_choice, lower_sym = :q_025, upper_sym = :q_975) - model_plotting_data = analysis_df |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :Used_GI_Mean.==used_gi_choice) |> - df -> @subset(df, - :Reference_Time.==reference_time) |> - df -> @subset(df, :Scenario.==scenario) |> - df -> @subset(df, :Target.==target) |> - df -> @subset(df, - :Latent_Model.==latentmodel) |> - data - - plt_model = model_plotting_data * - mapping(:target_times => "T", :q_5 => "Process values", - color = :IGP_Model => "Infection generating process") * - mapping(lower = lower_sym, upper = upper_sym) * visual(LinesFill) - - return plt_model -end - -""" -Internal method for creating a truth data panel plot for a given scenario and - target using the provided truth data. - -## Arguments -- `truth_df`: DataFrame containing the truth data. -- `scenario`: Scenario to plot. -- `target`: Target to plot. -- `true_gi_choice`: True GI choice to filter the data. - -## Returns -- `plt_truth`: Plot object representing the truth data panel. +# Arguments +- `ax`: The axis to plot on. +- `truth`: The DataFrame containing the truth data. +- `scenario`: The scenario to filter the truth data by. +- `target`: The target to filter the truth data by. +- `true_gi_choice`: The true generation interval choice to filter the truth data by (default is 2.0). +- `color`: The color of the scatter plot (default is :black). """ -function _figure_one_truth_data_panel(truth_df, scenario, target; true_gi_choice) - truth_plotting_data = truth_df |> - df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> - df -> @subset(df, :Scenario.==scenario) |> - df -> @subset(df, :Target.==target) |> data - plt_truth = truth_plotting_data * - mapping( - :target_times => "T", :target_values => "values", color = :IGP_Model) * - visual(Scatter) - return plt_truth +function _plot_truth!(ax, truth, scenario, target; true_gi_choice = 2.0, color = :black) + pred = truth |> + df -> @subset(df, :True_GI_Mean.==true_gi_choice) |> + df -> @subset(df, :Scenario.==scenario) |> + df -> @subset(df, :Target.==target) + x = pred[!, "target_times"] + y = pred[!, "target_values"] + scatter!(ax, x, y, color = color, label = "Data") + + return nothing end """ -Create figure one with multiple panels showing the analysis results and truth data for different scenarios and targets. +Generate a figure with multiple subplots showing predictions and truth data for different + scenarios and targets. # Arguments -- `truth_df::DataFrame`: The truth data as a DataFrame. -- `analysis_df::DataFrame`: The analysis data as a DataFrame. -- `latent_model::AbstractString`: The latent model to use for the infection generating process. -- `scenario_dict::Dict{AbstractString, Scenario}`: A dictionary mapping scenario names to Scenario objects. -- `target_dict::Dict{AbstractString, Target}`: A dictionary mapping target names to Target objects. -- `latent_model_dict::Dict{AbstractString, LatentModel}`: A dictionary mapping latent model names to LatentModel objects. - -# Optional Arguments -- `fig_kws::NamedTuple`: Keyword arguments for the Figure object. -- `true_gi_choice::Float64`: The true value of the infection generating process. -- `used_gi_choice::Float64`: The value of the infection generating process used in the analysis. -- `legend_title::AbstractString`: The title of the legend. -- `targets::Vector{AbstractString}`: The names of the targets to include in the figure. -- `scenarios::Vector{AbstractString}`: The names of the scenarios to include in the figure. - -# Returns -- `fig::Figure`: The figure object containing the panels. - -# Example -This assumes that the user already has the necessary dataframes `truth_df` and `analysis_df` loaded. - -```julia -using EpiAwarePipeline -# Define scenario titles and reference times for figure 1 -scenario_dict = Dict( - "measures_outbreak" => (title = "Outbreak with measures", T = 28), - "smooth_outbreak" => (title = "Outbreak no measures", T = 35), - "smooth_endemic" => (title = "Smooth endemic", T = 35), - "rough_endemic" => (title = "Rough endemic", T = 35) -) - -target_dict = Dict( - "log_I_t" => (title = "log(Incidence)", ylims = (3.5, 6)), - "rt" => (title = "Exp. growth rate", ylims = (-0.1, 0.1)), - "Rt" => (title = "Reproductive number", ylims = (-0.1, 3)) -) - -latent_model_dict = Dict( - "wkly_rw" => (title = "Random walk",), - "wkly_ar" => (title = "AR(1)",), - "wkly_diff_ar" => (title = "Diff. AR(1)",) -) - -fig1 = figureone( - truth_df, analysis_df, "wkly_ar", scenario_dict, target_dict, latent_model_dict) -``` +- `prediction_df::DataFrame`: DataFrame containing prediction data. +- `truth_data_df::DataFrame`: DataFrame containing truth data. +- `scenarios::Vector{String}`: List of scenario names. +- `targets::Vector{String}`: List of target names. +- `scenario_dict::Dict{String, Scenario}`: Dictionary mapping scenario names to scenario objects. +- `target_dict::Dict{String, Target}`: Dictionary mapping target names to target objects. +- `latent_model_dict::Dict{String, LatentModel}`: Dictionary mapping latent model names to latent model objects. +- `latent_model::String`: Name of the latent model to use (default: "ar"). +- `igps::Vector{String}`: List of infection generating processes (default: ["DirectInfections", "ExpGrowthRate", "Renewal"]). +- `true_gi_choice::Float64`: True generation interval choice (default: 2.0). +- `used_gi_choice::Float64`: Used generation interval choice (default: 2.0). +- `data_color::Symbol`: Color for the truth data (default: :black). +- `colors::Vector{Symbol}`: Colors for the predictions (default: [:red, :blue, :green]). +- `iqr_alpha::Float64`: Alpha value for the interquartile range shading (default: 0.3). """ function figureone( - truth_df, analysis_df, latent_model, scenario_dict, target_dict, - latent_model_dict; fig_kws = (; size = (1000, 1500)), - true_gi_choice = 10.0, used_gi_choice = 10.0, - legend_title = "Infection generating\n process", - targets = ["log_I_t", "rt", "Rt"], - scenarios = [ - "measures_outbreak", "smooth_outbreak", "smooth_endemic", "rough_endemic"]) - # Perform checks on the dataframes - _dataframe_checks(truth_df, analysis_df, scenario_dict) - latent_models = analysis_df.Latent_Model |> unique - @assert latent_model in latent_models "The latent model is not in the analysis data" - @assert latent_model in keys(latent_model_dict) "The latent model is not in the latent_model_dict dictionary" - @assert all([target in keys(target_dict) for target in targets]) "Not all targets are in the target dictionary" - @assert all([scenario in keys(scenario_dict) for scenario in scenarios]) "Not all scenarios are in the scenario dictionary" - - # Treat the truth data as a Latent model option - truth_df[!, "IGP_Model"] .= "Truth data" - - plt_truth_mat = [_figure_one_truth_data_panel( - truth_df, scenario, target; true_gi_choice) - for scenario in keys(scenario_dict), target in targets] - - plt_analysis_mat = [_figure_one_model_panel( - analysis_df, scenario, target, latent_model; - reference_time = scenario_dict[scenario].T, - true_gi_choice, used_gi_choice) - for scenario in keys(scenario_dict), target in targets] - - fig = Figure(; fig_kws...) - leg = nothing - for (i, scenario) in enumerate(scenarios) - for (j, target) in enumerate(targets) - sf = fig[i, j] - V = mapping([scenario_dict[scenario].T]) * - visual(VLines, color = :red, linewidth = 3) - - ag = draw!( - sf, plt_analysis_mat[i, j] + plt_truth_mat[i, j] + V, - axis = (; limits = (nothing, target_dict[target].ylims))) - leg = AlgebraOfGraphics.compute_legend(ag) - i == 1 && - Label(sf[0, 1], target_dict[target].title, fontsize = 22, font = :bold) - j == 3 && Label(sf[1, 2], scenario_dict[scenario].title, - fontsize = 18, font = :bold, rotation = -pi / 2) + prediction_df, truth_data_df, scenarios, targets; scenario_dict, target_dict, latent_model_dict, + latent_model = "ar", igps = ["DirectInfections", "ExpGrowthRate", "Renewal"], + true_gi_choice = 2.0, used_gi_choice = 2.0, data_color = :black, + colors = [:red, :blue, :green], iqr_alpha = 0.3) + fig = Figure(; size = (1000, 800)) + axs = mapreduce(hcat, enumerate(targets)) do (i, target) + map(enumerate(scenarios)) do (j, scenario) + ax = Axis(fig[i, j]) + _plot_predictions!( + ax, prediction_df, scenario, target, scenario_dict[scenario].T, + latent_model; true_gi_choice, used_gi_choice, colors, iqr_alpha, igps) + _plot_truth!( + ax, truth_data_df, scenario, target; true_gi_choice, color = data_color) + vlines!(ax, [scenario_dict[scenario].T], color = data_color, + linewidth = 3, label = "Horizon") + if i == 1 + ax.title = scenario_dict[scenario].title + end + if i == 3 + ax.xlabel = "Time" + end + if j == 1 + ax.ylabel = target_dict[target].title + end + ax end end - Label(fig[:, 0], "Process values", fontsize = 28, font = :bold, rotation = pi / 2) - Label(fig[5, 3], - "Latent model\n for infection\n generating\n process:\n$(latent_model_dict[latent_model].title)", - fontsize = 18, - font = :bold) - - _leg = (leg[1], leg[2], [legend_title]) - Legend(fig[5, 2], _leg...) - + leg = Legend(fig[length(targets) + 1, 1:3], last(axs), "Infection generating process"; + orientation = :horizontal, tellwidth = false, framevisible = false) + lab = Label(fig[length(targets) + 1, length(scenarios)], + "Latent model for \n infection generating\n process: $(latent_model_dict[latent_model].title)"; + tellwidth = false, + fontsize = 18) resize_to_layout!(fig) - return fig + fig end