diff --git a/_targets_eval_postprocessing.R b/_targets_eval_postprocessing.R index 0fa27e43..c45d3067 100644 --- a/_targets_eval_postprocessing.R +++ b/_targets_eval_postprocessing.R @@ -1464,6 +1464,22 @@ benchmarks <- list( wwinference_version = eval_config$wwinference_version, overwrite_benchmark = eval_config$overwrite_benchmark ) + ), + tar_target( + name = plot_benchmark_by_loc, + command = plot_benchmarks( + grouping_var = "location", + benchmark_scope = "all_forecasts", + benchmark_dir = benchmark_config$benchmark_dir + ) + ), + tar_target( + name = plot_benchmark_by_forecast_date, + command = plot_benchmarks( + grouping_var = "forecast_date", + benchmark_scope = "all_forecasts", + benchmark_dir = benchmark_config$benchmark_dir + ) ) ) diff --git a/_targets_subset_benchmarking.R b/_targets_subset_benchmarking.R index b903a53b..d83b7d84 100644 --- a/_targets_subset_benchmarking.R +++ b/_targets_subset_benchmarking.R @@ -82,7 +82,7 @@ combined_targets <- list( # Benchmarking---------------------------------------------------------- benchmarks <- list( tar_target( - name = write_benchmark_table_subset_run, + name = benchmark_table_subset_run, command = benchmark_performance( ww_scores = ww_scores, hosp_scores = hosp_scores, @@ -91,6 +91,24 @@ benchmarks <- list( wwinference_version = benchmark_config$wwinference_version, overwrite_benchmark = benchmark_config$overwrite_benchmark ) + ), + tar_target( + name = plot_benchmark_by_loc, + command = plot_benchmarks( + grouping_var = "location", + benchmark_scope = "subset_forecasts", + benchmark_dir = benchmark_config$benchmark_dir, + scores_list = benchmark_table_subset_run + ) + ), + tar_target( + name = plot_benchmark_by_forecast_date, + command = plot_benchmarks( + grouping_var = "forecast_date", + benchmark_scope = "subset_forecasts", + benchmark_dir = benchmark_config$benchmark_dir, + scores_list = benchmark_table_subset_run + ) ) ) diff --git a/command_line_eval.R b/command_line_eval.R index 97fead70..a778786b 100644 --- a/command_line_eval.R +++ b/command_line_eval.R @@ -5,7 +5,6 @@ # in parallel on Azure batch library(argparser, quietly = TRUE) -library(cfaforecastrenewalww) library(wweval) library(ggplot2) # some functions from plots.R complain about aes() function not existing if we don't load ggplot2 diff --git a/input/params.toml b/input/params.toml index c628e4f4..28951d1a 100644 --- a/input/params.toml +++ b/input/params.toml @@ -41,9 +41,10 @@ offset_ref_initial_exp_growth_rate_prior_sd = 0.025 autoreg_p_hosp_a = 1 # shape1 parameter of autoreg term on IHR(t) trend autoreg_p_hosp_b = 100 # shape2 parameter of autoreg term on IHR(t) trend -eta_sd_sd = 0.01 -infection_feedback_prior_logmean = 6.37408 # log(mode) + q^2 mode = 500, q = 0.4 -infection_feedback_prior_logsd = 0.4 +eta_sd_sd = 0.0097 # this is the mean, sd is 0.0097 +eta_sd_mean = 0.0278 +infection_feedback_prior_logmean = 4.498 # mode=100 log(mode) + q^2 mode = 100, q = 1 +infection_feedback_prior_logsd = 0.636 # roughly 16 to 600 in 95% ci [hospital_admission_observation_process] # Hospitalization parameters (informative priors) diff --git a/output/benchmarking/latest_subset_forecasts_by_forecast_date.tsv b/output/benchmarking/latest_subset_forecasts_by_forecast_date.tsv index 58845827..5fb90451 100644 --- a/output/benchmarking/latest_subset_forecasts_by_forecast_date.tsv +++ b/output/benchmarking/latest_subset_forecasts_by_forecast_date.tsv @@ -1,7 +1,7 @@ forecast_date crps_hosp crps_ww bias_hosp bias_ww ae_hosp ae_ww wweval_commit_hash wwinference_version time_stamp -2023-10-16 5.145227727175137 4.405852136963368 0.48245657894736843 0.45490394736842105 6.950701593150161 6.041583425908238 d23b28a v0.1.0 2024-10-24T17:48:21Z -2023-11-13 4.66700940996615 4.019961671884518 -0.3123618421052632 -0.22545263157894738 6.145553133172695 5.324888683286438 d23b28a v0.1.0 2024-10-24T17:48:21Z -2023-12-11 8.859086840566519 8.837148573389927 -0.5840236842105263 -0.511771052631579 11.50992521124024 11.344388151022924 d23b28a v0.1.0 2024-10-24T17:48:21Z -2024-01-08 17.39089396390878 15.609039964313409 0.25387105263157894 0.3958921052631579 22.182123575526884 20.51478614874109 d23b28a v0.1.0 2024-10-24T17:48:21Z -2024-02-05 3.324198357881111 5.079010312322833 0.26249342105263157 0.3663105263157895 4.572098587350526 7.344139936498404 d23b28a v0.1.0 2024-10-24T17:48:21Z -2024-03-04 3.279328388379696 4.706364986307469 0.49960394736842106 0.6323973684210527 4.718323317619102 6.606204242479576 d23b28a v0.1.0 2024-10-24T17:48:21Z +2023-10-16 5.874998928099601 4.842755313016674 0.4281315789473684 0.40362236842105265 8.202210281292677 6.916940992684773 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2023-11-13 5.514058896168103 4.2595719435319985 -0.4940644736842105 -0.3491868421052632 7.648424162053213 5.968291847862045 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2023-12-11 5.367768349490639 5.48631763373991 -0.34426973684210527 -0.20431052631578947 8.157250251202527 7.768904286971368 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2024-01-08 35.16983309770466 35.53006450254746 0.5046828947368421 0.5802315789473684 44.28701944409863 45.39614861289537 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2024-02-05 2.773028914102106 3.711888477140153 -0.07066578947368421 0.15726973684210527 3.12398987017981 5.094208909464286 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2024-03-04 2.751103566063376 4.452629873235709 0.3015644736842105 0.4553171052631579 3.7821969793270878 6.373191104611754 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z diff --git a/output/benchmarking/latest_subset_forecasts_by_location.tsv b/output/benchmarking/latest_subset_forecasts_by_location.tsv index f5cc6cbd..908f2a2b 100644 --- a/output/benchmarking/latest_subset_forecasts_by_location.tsv +++ b/output/benchmarking/latest_subset_forecasts_by_location.tsv @@ -1,7 +1,7 @@ crps_hosp crps_ww bias_hosp bias_ww ae_hosp ae_ww location wweval_commit_hash wwinference_version time_stamp -1.0465392954054191 1.0437418123363529 -0.052574561403508774 -0.05236184210526316 1.430541441794062 1.4249102564297214 AK d23b28a v0.1.0 2024-10-24T17:48:21Z -11.624616908337584 11.001390757513024 0.23765679824561403 0.30562280701754385 15.2169931884645 14.921735155883097 MA d23b28a v0.1.0 2024-10-24T17:48:21Z -1.9743211076933298 1.7647838214371094 0.10642543859649123 0.11702741228070176 2.6690277913463327 2.360288987441267 NH d23b28a v0.1.0 2024-10-24T17:48:21Z -17.371192144075163 17.18502224109214 0.16711842105263158 0.24489583333333334 22.505788770781994 22.54254318981423 NJ d23b28a v0.1.0 2024-10-24T17:48:21Z -3.5381177843863316 4.552876071939314 0.0430734649122807 0.31171600877192984 4.9099199893294525 6.397181233712242 WA d23b28a v0.1.0 2024-10-24T17:48:21Z -7.110957447979565 7.109562940863587 0.10033991228070176 0.18538004385964912 9.346454236343268 9.529331764656112 all d23b28a v0.1.0 2024-10-24T17:48:21Z +0.9048110671853405 0.8654847567756644 -0.07057236842105263 -0.07581030701754386 1.2465896546056678 1.2173595112401607 AK 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +14.762354885967529 12.681244008639368 0.18578947368421053 0.30892434210526315 19.250396069656812 16.835413776804067 MA 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2.100055219513157 1.6857065406611689 -0.036726973684210525 0.038880482456140356 2.984162257895195 2.31003750944109 NH 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +26.360153847097493 25.59449289371047 0.1790548245614035 0.2609703947368421 33.681565511153316 33.56127504975967 NJ 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +3.7482847732602225 7.7424282528899155 0.013604166666666669 0.3361546052631579 5.504862330150637 10.673985614829668 WA 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +9.575131958604748 9.713871290535318 0.05422982456140351 0.17382390350877194 12.533515164692325 12.919614292414932 all 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z diff --git a/output/benchmarking/plots/subset_forecasts_by_forecast_date.png b/output/benchmarking/plots/subset_forecasts_by_forecast_date.png new file mode 100644 index 00000000..b4cd2707 Binary files /dev/null and b/output/benchmarking/plots/subset_forecasts_by_forecast_date.png differ diff --git a/output/benchmarking/plots/subset_forecasts_by_location.png b/output/benchmarking/plots/subset_forecasts_by_location.png new file mode 100644 index 00000000..93ebde67 Binary files /dev/null and b/output/benchmarking/plots/subset_forecasts_by_location.png differ diff --git a/output/benchmarking/plots/subset_forecasts_overall.png b/output/benchmarking/plots/subset_forecasts_overall.png new file mode 100644 index 00000000..0ab74b20 Binary files /dev/null and b/output/benchmarking/plots/subset_forecasts_overall.png differ diff --git a/output/benchmarking/subset_forecasts_by_forecast_date.tsv b/output/benchmarking/subset_forecasts_by_forecast_date.tsv index 58845827..65eeecf4 100644 --- a/output/benchmarking/subset_forecasts_by_forecast_date.tsv +++ b/output/benchmarking/subset_forecasts_by_forecast_date.tsv @@ -1,6 +1,18 @@ forecast_date crps_hosp crps_ww bias_hosp bias_ww ae_hosp ae_ww wweval_commit_hash wwinference_version time_stamp -2023-10-16 5.145227727175137 4.405852136963368 0.48245657894736843 0.45490394736842105 6.950701593150161 6.041583425908238 d23b28a v0.1.0 2024-10-24T17:48:21Z -2023-11-13 4.66700940996615 4.019961671884518 -0.3123618421052632 -0.22545263157894738 6.145553133172695 5.324888683286438 d23b28a v0.1.0 2024-10-24T17:48:21Z +2023-10-16 5.874998928099601 4.842755313016674 0.4281315789473684 0.40362236842105265 8.202210281292677 6.916940992684773 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2023-11-13 5.514058896168103 4.2595719435319985 -0.4940644736842105 -0.3491868421052632 7.648424162053213 5.968291847862045 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2023-12-11 5.367768349490639 5.48631763373991 -0.34426973684210527 -0.20431052631578947 8.157250251202527 7.768904286971368 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2024-01-08 35.16983309770466 35.53006450254746 0.5046828947368421 0.5802315789473684 44.28701944409863 45.39614861289537 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2024-02-05 2.773028914102106 3.711888477140153 -0.07066578947368421 0.15726973684210527 3.12398987017981 5.094208909464286 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2024-03-04 2.751103566063376 4.452629873235709 0.3015644736842105 0.4553171052631579 3.7821969793270878 6.373191104611754 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2023-10-16 9.243950678640694 6.806804188258501 0.5824578947368421 0.4910565789473684 12.075811587773888 9.080835846088256 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2023-11-13 6.798783532681564 4.377860426544979 -0.572275 -0.32701842105263157 8.806163742433748 5.917972294575069 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2023-12-11 5.578345847009017 5.404499318529154 -0.22773947368421052 -0.12964736842105262 7.0379402652994925 6.487141426491198 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2024-01-08 73.02250974305282 64.10560452659152 0.5030894736842105 0.5657434210526315 90.10392730630215 79.63998299207216 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2024-02-05 3.2543113736957956 4.331650777414377 -0.16031315789473685 0.14943684210526317 4.889963029507771 5.687676001142328 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2024-03-04 2.4436722259876116 4.390295509081186 0.2651552631578947 0.5207526315789474 3.274852617671909 6.112645792027634 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2023-10-16 5.145227727175137 4.405852136963368 0.4824565789473684 0.45490394736842105 6.950701593150161 6.041583425908238 d23b28a v0.1.0 2024-10-24T17:48:21Z +2023-11-13 4.66700940996615 4.019961671884518 -0.3123618421052632 -0.22545263157894735 6.145553133172695 5.324888683286438 d23b28a v0.1.0 2024-10-24T17:48:21Z 2023-12-11 8.859086840566519 8.837148573389927 -0.5840236842105263 -0.511771052631579 11.50992521124024 11.344388151022924 d23b28a v0.1.0 2024-10-24T17:48:21Z 2024-01-08 17.39089396390878 15.609039964313409 0.25387105263157894 0.3958921052631579 22.182123575526884 20.51478614874109 d23b28a v0.1.0 2024-10-24T17:48:21Z 2024-02-05 3.324198357881111 5.079010312322833 0.26249342105263157 0.3663105263157895 4.572098587350526 7.344139936498404 d23b28a v0.1.0 2024-10-24T17:48:21Z diff --git a/output/benchmarking/subset_forecasts_by_location.tsv b/output/benchmarking/subset_forecasts_by_location.tsv index f5cc6cbd..34dc516f 100644 --- a/output/benchmarking/subset_forecasts_by_location.tsv +++ b/output/benchmarking/subset_forecasts_by_location.tsv @@ -1,7 +1,19 @@ crps_hosp crps_ww bias_hosp bias_ww ae_hosp ae_ww location wweval_commit_hash wwinference_version time_stamp -1.0465392954054191 1.0437418123363529 -0.052574561403508774 -0.05236184210526316 1.430541441794062 1.4249102564297214 AK d23b28a v0.1.0 2024-10-24T17:48:21Z +0.9048110671853405 0.8654847567756644 -0.07057236842105263 -0.07581030701754386 1.2465896546056678 1.2173595112401607 AK 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +14.762354885967529 12.681244008639368 0.18578947368421053 0.30892434210526315 19.250396069656812 16.835413776804067 MA 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +2.100055219513157 1.6857065406611689 -0.036726973684210525 0.038880482456140356 2.984162257895195 2.31003750944109 NH 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +26.360153847097493 25.59449289371047 0.1790548245614035 0.2609703947368421 33.681565511153316 33.56127504975967 NJ 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +3.7482847732602225 7.7424282528899155 0.013604166666666669 0.3361546052631579 5.504862330150637 10.673985614829668 WA 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +9.575131958604748 9.713871290535318 0.05422982456140351 0.17382390350877194 12.533515164692325 12.919614292414932 all 11a9898 227-inf-feedback-posterior-mod-eta 2024-10-30T19:46:57Z +1.0376521512900907 1.0118669366748054 -0.03699122807017543 -0.00765679824561403 1.4245526802625763 1.4264265157766278 AK 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +24.626800317489625 19.49484775911998 0.2060657894736842 0.4117872807017544 30.689131156635455 24.78340014717931 MA 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +2.9839338399865656 1.8872799196790224 -0.01010087719298246 0.04486293859649123 4.248055546108953 2.2765310519617756 NH 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +50.50136383077716 41.39684785751621 0.1817642543859649 0.2782510964912281 62.26077007486285 51.60308220739496 NJ 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +4.468227694679483 10.723086482359754 -0.01542543859649122 0.33135855263157893 6.534705999620968 14.01577203801786 WA 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +16.723595566844583 14.902785791069954 0.0650625 0.2117206140350877 21.03144309149816 18.821042392066108 all 51ddee5 227-inf-feedback 2024-10-25T17:16:51Z +1.0465392954054191 1.0437418123363529 -0.05257456140350877 -0.05236184210526316 1.430541441794062 1.4249102564297214 AK d23b28a v0.1.0 2024-10-24T17:48:21Z 11.624616908337584 11.001390757513024 0.23765679824561403 0.30562280701754385 15.2169931884645 14.921735155883097 MA d23b28a v0.1.0 2024-10-24T17:48:21Z -1.9743211076933298 1.7647838214371094 0.10642543859649123 0.11702741228070176 2.6690277913463327 2.360288987441267 NH d23b28a v0.1.0 2024-10-24T17:48:21Z -17.371192144075163 17.18502224109214 0.16711842105263158 0.24489583333333334 22.505788770781994 22.54254318981423 NJ d23b28a v0.1.0 2024-10-24T17:48:21Z +1.9743211076933296 1.7647838214371094 0.10642543859649124 0.11702741228070176 2.6690277913463327 2.360288987441267 NH d23b28a v0.1.0 2024-10-24T17:48:21Z +17.371192144075163 17.18502224109214 0.16711842105263158 0.24489583333333337 22.505788770781994 22.54254318981423 NJ d23b28a v0.1.0 2024-10-24T17:48:21Z 3.5381177843863316 4.552876071939314 0.0430734649122807 0.31171600877192984 4.9099199893294525 6.397181233712242 WA d23b28a v0.1.0 2024-10-24T17:48:21Z 7.110957447979565 7.109562940863587 0.10033991228070176 0.18538004385964912 9.346454236343268 9.529331764656112 all d23b28a v0.1.0 2024-10-24T17:48:21Z diff --git a/scratch/get_empirical_posterior_vals.R b/scratch/get_empirical_posterior_vals.R new file mode 100644 index 00000000..92479d76 --- /dev/null +++ b/scratch/get_empirical_posterior_vals.R @@ -0,0 +1,72 @@ +# Quick estimate of posterior parameters + +benchmark_config <- yaml::read_yaml(file.path( + "input", "config", + "eval", "benchmark_config.yaml" +)) + +vars <- c("eta_sd", "inf_feedback") +eta_sd_draws <- tibble::tibble() +inf_feedback_draws <- tibble::tibble() + +for (i in seq_along(benchmark_config$forecast_date_hosp)) { + this_location <- benchmark_config$location_hosp[i] + this_forecast_date <- benchmark_config$forecast_date_hosp[i] + this_scenario <- "no_wastewater" + for (j in seq_along(vars)) { + fp_var <- wweval::get_filepath(benchmark_config$output_dir, + scenario = this_scenario, + forecast_date = this_forecast_date, + model_type = "hosp", + location = this_location, + output_type = vars[j], + file_extension = "tsv" + ) + + these_var_draws <- readr::read_tsv(fp_var) + var_draws <- these_var_draws |> + dplyr::mutate( + location = this_location, + forecast_date = this_forecast_date + ) + if (vars[j] == "eta_sd") { + eta_sd_draws <- dplyr::bind_rows(eta_sd_draws, var_draws) + } + if (vars[j] == "inf_feedback") { + inf_feedback_draws <- dplyr::bind_rows( + inf_feedback_draws, + var_draws + ) + } + } # end loop over vars +} # end loop over forecast date-locations + +# Get empirical mean, sd, logmean, and logsd------------------------------ +## eta_sd--------------------------------------------------------------- +mean_eta_sd <- mean(eta_sd_draws$eta_sd) +sd_eta_sd <- sd(eta_sd_draws$eta_sd) + +message("Empirical mean of RW step size across 5 locations: ", mean_eta_sd) +message("Empirical sd of RW step size across 5 locations: ", sd_eta_sd) + +## inf_feedback---------------------------------------------------------- +logmean_inf_feedback <- mean(log(inf_feedback_draws$infection_feedback)) +logsd_inf_feedback <- sd(log(inf_feedback_draws$infection_feedback)) + + +message( + "Empirical logmean of infection feedback across 5 locations: ", + logmean_inf_feedback +) +message( + "Empirical logsd of infection feedback across 5 locations: ", + logsd_inf_feedback +) + +posterior_params <- list( + mean_eta_sd = mean_eta_sd, + sd_eta_sd = sd_eta_sd, + logmean_inf_feedback = logmean_inf_feedback, + logsd_inf_feedback = logsd_inf_feedback +) +yaml::write_yaml(posterior_params, "output/posterior_params.yaml") diff --git a/src/setup_eval.R b/src/setup_eval.R index a7dcd1d5..d92e3d5b 100644 --- a/src/setup_eval.R +++ b/src/setup_eval.R @@ -32,5 +32,7 @@ write_eval_config( overwrite_summary_table = FALSE, # Set as TRUE if trying to get a baseline # score for all locations one forecast date overwrite_benchmark = TRUE, # Set as TRUE if want to save outputs of - wwinference_version = "v0.1.0" + + # benchmarking in directory + wwinference_version = "227-inf-feedback-posterior-mod-eta" ) diff --git a/src/setup_subset_benchmarking.R b/src/setup_subset_benchmarking.R index a3e7f8e1..35c351a8 100644 --- a/src/setup_subset_benchmarking.R +++ b/src/setup_subset_benchmarking.R @@ -29,5 +29,5 @@ write_eval_config( overwrite_benchmark = TRUE, # Set as TRUE if want to save outputs of # benchmarking in directory, name_of_config = "benchmark_config", - wwinference_version = "v0.1.0" + wwinference_version = "227-inf-feedback-posterior-mod-eta" ) diff --git a/src/write_eval_config.R b/src/write_eval_config.R index 349a76ba..44ad947c 100644 --- a/src/write_eval_config.R +++ b/src/write_eval_config.R @@ -6,9 +6,10 @@ #' @param scenatios the scenarios (which will pertain to site ids) to #' run the model on #' @param config_dir the directory where we want to save the config file +#' @param scenario_dir the directory where the files defining scenarios +#' (default `.tsv` format) are located #' @param benchmark_dir the directory where to save the benchmarked performance #' for this run -#' @param ms_fig_dir the directory to save the manuscript figures in #' @param eval_date the data of the evaluation dataset, in ISO YYYY-MM-DD format #' @param overwrite_summary_table Boolean indicating whether or not to overwrite #' internal summary table @@ -191,6 +192,7 @@ write_eval_config <- function(locations, forecast_dates, adapt_delta = adapt_delta, max_treedepth = max_treedepth, seed = seed, + name_of_config = name_of_config, # Input delay distributions generation_interval = generation_interval, infection_feedback_pmf = generation_interval, diff --git a/wweval/DESCRIPTION b/wweval/DESCRIPTION index 1f1a53f0..b10bc6b8 100644 --- a/wweval/DESCRIPTION +++ b/wweval/DESCRIPTION @@ -90,5 +90,5 @@ Config/Needs/check: rcmdcheck, testthat RoxygenNote: 7.3.2 Remotes: stan-dev/cmdstanr, - wwinference=CDCgov/ww-inference-model@v0.1.0 + wwinference=CDCgov/ww-inference-model@227-inf-feedback-mod LazyData: true diff --git a/wweval/NAMESPACE b/wweval/NAMESPACE index c1d94a38..a5ca7d91 100644 --- a/wweval/NAMESPACE +++ b/wweval/NAMESPACE @@ -87,6 +87,7 @@ export(nhsn_soda_query) export(order_horizons) export(order_periods) export(order_phases) +export(plot_benchmarks) export(plot_components) export(plot_quantiles) export(pull_nhsn) diff --git a/wweval/R/benchmarking.R b/wweval/R/benchmarking.R index 2a59da6e..d1ac3379 100644 --- a/wweval/R/benchmarking.R +++ b/wweval/R/benchmarking.R @@ -93,6 +93,7 @@ benchmark_performance <- function(ww_scores, dplyr::select(colnames(overall_scores)) |> dplyr::bind_rows(overall_scores) + benchmarks <- list( scores_by_forecast_date = scores_by_forecast_date, scores_by_location = scores_by_location @@ -219,3 +220,109 @@ benchmark_performance <- function(ww_scores, return(benchmarks) } + + +#' Plot benchmark summaries +#' +#' @param grouping_var The variable to plot by, so in this case either +#' `location` or `forecast_date` +#' @param benchmark_scope The scope of the benchmarking (so in this case +#' either `subset_forecasts` or `all_forecasts`) +#' @param benchmark_dir The directory where the benchmark tables live +#' @param scores_list The list of tables of recent scores +#' @param score_to_plot Which of the scores saved in the benchmarking +#' tables to plot, options are `crps`, `bias`, and `ae`, defualt is `crps` +#' @param write_files Boolean indicating whether or not to save the plots +#' to disk, default is TRUE +#' +#' @return a ggplot object containing a bar chart colored by package/model +#' version and faceted by the grouping variable, with the height of the bar +#' indicating the scores +#' @export +plot_benchmarks <- function(grouping_var, + benchmark_scope, + benchmark_dir, + scores_list, + score_to_plot = "crps", + write_files = TRUE) { + # Load in table + fp <- glue::glue("{benchmark_dir}/{benchmark_scope}_by_{grouping_var}.tsv") + df <- readr::read_tsv(fp) + + # pivot_longer for plotting + df_long <- df |> + tidyr::pivot_longer( + cols = crps_hosp:ae_ww, + names_to = c("score_type", "model"), + names_pattern = "(.*)_(.*)", + values_to = "score" + ) + + if (grouping_var == "location") { + df_all <- df_long |> + dplyr::filter(location == "all") + df_long <- df_long |> + dplyr::filter(location != "all") + + p_all <- ggplot(df_all) + + geom_bar( + aes( + x = model, y = score, + fill = wwinference_version + ), + stat = "identity", + position = "dodge" + ) + + facet_wrap(~score_type, scales = "free_y") + + theme( + legend.position = "bottom", + panel.background = element_rect(fill = "white") + ) + + ggtitle("Overall performance benchmarking") + if (isTRUE(write_files)) { + ggsave( + filename = glue::glue( + "{benchmark_dir}/plots/{benchmark_scope}_overall.png" + ), + plot = p_all, + create.dir = TRUE + ) + } + } + + p <- ggplot(df_long |> + dplyr::filter(score_type == { + score_to_plot + })) + + geom_bar( + aes( + x = model, y = score, + fill = wwinference_version + ), + stat = "identity", + position = "dodge" + ) + + facet_wrap( + { + grouping_var + }, + scales = "free_y" + ) + + theme( + legend.position = "bottom", + panel.background = element_rect(fill = "white") + ) + + ylab("CRPS") + + if (isTRUE(write_files)) { + ggsave( + filename = glue::glue( + "{benchmark_dir}/plots/{benchmark_scope}_by_{grouping_var}.png" + ), + plot = p, + create.dir = TRUE + ) + } + + return(p) +} diff --git a/wweval/R/eval_post_process.R b/wweval/R/eval_post_process.R index afe0cd0a..57ec8fe1 100644 --- a/wweval/R/eval_post_process.R +++ b/wweval/R/eval_post_process.R @@ -63,6 +63,70 @@ eval_post_process_ww <- function(config_index, ) save_object("raw_flags", output_file_suffix) + # Save some plots of the posterior parameters--------------------- + inf_feedback <- ww_raw_draws |> + tidybayes::spread_draws(!!str2lang("infection_feedback")) |> + dplyr::mutate( + draw = .data$`.draw`, + ) |> + dplyr::select("infection_feedback", "draw") + + p_inf <- ggplot( + inf_feedback, + aes(x = infection_feedback) + ) + + geom_histogram() + + eta_sd <- ww_raw_draws |> + tidybayes::spread_draws(!!str2lang("eta_sd")) |> + dplyr::mutate( + draw = .data$`.draw`, + ) |> + dplyr::select("eta_sd", "draw") + + p_eta_sd <- ggplot( + eta_sd, + aes(x = eta_sd) + ) + + geom_histogram() + + ggsave(p_inf, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "inf_feedback.png" + ), + create.dir = TRUE + ) + ggsave(p_eta_sd, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "eta_sd.png" + ), + create.dir = TRUE + ) + + save_table( + data_to_save = eta_sd, + type_of_output = "eta_sd", + output_dir = output_dir, + scenario = scenario, + forecast_date = forecast_date, + model_type = "ww", + location = location + ) + save_table( + data_to_save = inf_feedback, + type_of_output = "inf_feedback", + output_dir = output_dir, + scenario = scenario, + forecast_date = forecast_date, + model_type = "ww", + location = location + ) + + # Make the data look like it did in wweval------------------------------- input_hosp_data_wweval <- input_hosp_data |> dplyr:::rename( @@ -160,12 +224,14 @@ eval_post_process_ww <- function(config_index, rate = "weekly" ) - ggsave(plot_growth_rates, filename = file.path( - output_dir, scenario, - forecast_date, "ww", location, - "plot_growth_rates.png" - )) - + ggsave(plot_growth_rates, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "plot_growth_rates.png" + ), + create.dir = TRUE + ) hosp_draws <- { if (!is.null(ww_fit_obj_wwinference$error)) { @@ -282,11 +348,14 @@ eval_post_process_ww <- function(config_index, } } - ggsave(plot_hosp_draws, filename = file.path( - output_dir, scenario, - forecast_date, "ww", location, - "plot_hosp_draws.png" - )) + ggsave(plot_hosp_draws, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "plot_hosp_draws.png" + ), + create.dir = TRUE + ) plot_hosp_t <- make_fig2_hosp_t( hosp_quantiles = full_hosp_quantiles, @@ -296,14 +365,48 @@ eval_post_process_ww <- function(config_index, ggtitle(glue::glue("{location} on {forecast_date}")) + theme_bw() - ggsave(plot_hosp_t, filename = file.path( - output_dir, scenario, - forecast_date, "ww", location, - "plot_hosp_t.png" - )) + ggsave(plot_hosp_t, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "plot_hosp_t.png" + ), + create.dir = TRUE + ) save_object("plot_hosp_draws", output_file_suffix) + # Plots of R(t)s + draws <- wwinference::get_draws(ww_fit_obj_wwinference, what = "all") + + plot_state_rt <- wwinference::get_plot_global_rt( + draws$global_rt, + forecast_date + ) + ggsave(plot_state_rt, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "plot_state_rt.png" + ), + create.dir = TRUE + ) + + plot_subpop_rt <- wwinference::get_plot_subpop_rt( + draws$subpop_rt, + forecast_date + ) + ggsave(plot_subpop_rt, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", location, + "plot_subpop_rt.png" + ), + create.dir = TRUE + ) + + + plot_ww_draws <- { if (is.null(ww_draws)) { NULL @@ -316,11 +419,14 @@ eval_post_process_ww <- function(config_index, } } - ggsave(plot_ww_draws, filename = file.path( - output_dir, scenario, - forecast_date, "ww", - location, "plot_ww_draws.png" - )) + ggsave(plot_ww_draws, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", + location, "plot_ww_draws.png" + ), + create.dir = TRUE + ) plot_ww_t <- make_fig2_ct( full_ww_quantiles, @@ -332,11 +438,14 @@ eval_post_process_ww <- function(config_index, ggtitle(glue::glue("{location} on {forecast_date}")) + theme_bw() - ggsave(plot_ww_t, filename = file.path( - output_dir, scenario, - forecast_date, "ww", - location, "plot_ww_t.png" - )) + ggsave(plot_ww_t, + filename = file.path( + output_dir, scenario, + forecast_date, "ww", + location, "plot_ww_t.png" + ), + create.dir = TRUE + ) save_object("plot_ww_draws", output_file_suffix) @@ -441,6 +550,72 @@ eval_post_process_hosp <- function(config_index, location = location ) + # Make plots of posterior params------------------- + + inf_feedback <- hosp_raw_draws |> + tidybayes::spread_draws(!!str2lang("infection_feedback")) |> + dplyr::mutate( + draw = .data$`.draw`, + ) |> + dplyr::select("infection_feedback", "draw") + + p_inf <- ggplot( + inf_feedback, + aes(x = infection_feedback) + ) + + geom_histogram() + + eta_sd <- hosp_raw_draws |> + tidybayes::spread_draws(!!str2lang("eta_sd")) |> + dplyr::mutate( + draw = .data$`.draw`, + ) |> + dplyr::select("eta_sd", "draw") + + p_eta_sd <- ggplot( + eta_sd, + aes(x = eta_sd) + ) + + geom_histogram() + + ggsave( + filename = file.path( + output_dir, scenario, + forecast_date, "hosp", location, + "inf_feedback.png" + ), + p_inf, + create.dir = TRUE + ) + ggsave( + filename = file.path( + output_dir, scenario, + forecast_date, "hosp", location, + "eta_sd.png" + ), + p_eta_sd, + create.dir = TRUE + ) + save_table( + data_to_save = eta_sd, + type_of_output = "eta_sd", + output_dir = output_dir, + scenario = scenario, + forecast_date = forecast_date, + model_type = "hosp", + location = location + ) + save_table( + data_to_save = inf_feedback, + type_of_output = "inf_feedback", + output_dir = output_dir, + scenario = scenario, + forecast_date = forecast_date, + model_type = "hosp", + location = location + ) + + # Get evaluation data from hospital admissions and wastewater # Join draws with flags + data and metadata @@ -504,21 +679,42 @@ eval_post_process_hosp <- function(config_index, forecast_date, "hosp", location, "plot_hosp_draws.png" ), - bg = "white" + bg = "white", + create.dir = TRUE ) plot_hosp_t <- make_fig2_hosp_t( hosp_quantiles = full_hosp_model_quantiles, loc_to_plot = location, - date_to_plot = forecast_date + date_to_plot = forecast_date, + n_calib_days = eval_config$calibration_time ) + ggtitle(glue::glue("{location} on {forecast_date}")) - ggsave(plot_hosp_t, filename = file.path( - output_dir, scenario, - forecast_date, "hosp", location, - "plot_hosp_t.png" - )) + ggsave(plot_hosp_t, + filename = file.path( + output_dir, scenario, + forecast_date, "hosp", location, + "plot_hosp_t.png" + ), + create.dir = TRUE + ) + + # Plots of R(t)s + draws <- wwinference::get_draws(hosp_fit_obj_wwinference, what = "global_rt") + + plot_state_rt <- wwinference::get_plot_global_rt( + draws$global_rt, + forecast_date + ) + ggsave(plot_state_rt, + filename = file.path( + output_dir, scenario, + forecast_date, "hosp", location, + "plot_state_rt.png" + ), + create.dir = TRUE + ) ## Score the hospital admissions only model------------------------- hosp_scores <- get_full_scores(hosp_model_hosp_draws, diff --git a/wweval/R/exclude_hosp_outliers.R b/wweval/R/exclude_hosp_outliers.R index 5aa97c81..839c7908 100644 --- a/wweval/R/exclude_hosp_outliers.R +++ b/wweval/R/exclude_hosp_outliers.R @@ -26,15 +26,14 @@ exclude_hosp_outliers <- function(raw_input_hosp_data, stopifnot("Only one location passed in" = length(loc) == 1) - exclusions <- table_of_exclusions |> - dplyr::filter( - location == loc, - forecast_date == forecast_date - ) - - if (nrow(exclusions) == 0) { + if (nrow(table_of_exclusions) == 0) { input_hosp_data <- raw_input_hosp_data } else { + exclusions <- table_of_exclusions |> + dplyr::filter( + location == loc, + forecast_date == forecast_date + ) dates_to_exclude <- exclusions |> dplyr::pull({{ col_name_dates_to_exclude }}) input_hosp_data <- raw_input_hosp_data |> diff --git a/wweval/R/ms_fig4.R b/wweval/R/ms_fig4.R index 53fca302..f0f67a1e 100644 --- a/wweval/R/ms_fig4.R +++ b/wweval/R/ms_fig4.R @@ -653,7 +653,7 @@ FGGG ) #+ plot_annotation(tag_levels = "A") #nolint not working fig4 - create_dir(fig_file_dir) + fs::dir_create(fig_file_dir) ggsave(fig4, filename = file.path(fig_file_dir, "fig4.png"), diff --git a/wweval/man/plot_benchmarks.Rd b/wweval/man/plot_benchmarks.Rd new file mode 100644 index 00000000..abb35e1a --- /dev/null +++ b/wweval/man/plot_benchmarks.Rd @@ -0,0 +1,40 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/benchmarking.R +\name{plot_benchmarks} +\alias{plot_benchmarks} +\title{Plot benchmark summaries} +\usage{ +plot_benchmarks( + grouping_var, + benchmark_scope, + benchmark_dir, + scores_list, + score_to_plot = "crps", + write_files = TRUE +) +} +\arguments{ +\item{grouping_var}{The variable to plot by, so in this case either +\code{location} or \code{forecast_date}} + +\item{benchmark_scope}{The scope of the benchmarking (so in this case +either \code{subset_forecasts} or \code{all_forecasts})} + +\item{benchmark_dir}{The directory where the benchmark tables live} + +\item{scores_list}{The list of tables of recent scores} + +\item{score_to_plot}{Which of the scores saved in the benchmarking +tables to plot, options are \code{crps}, \code{bias}, and \code{ae}, defualt is \code{crps}} + +\item{write_files}{Boolean indicating whether or not to save the plots +to disk, default is TRUE} +} +\value{ +a ggplot object containing a bar chart colored by package/model +version and faceted by the grouping variable, with the height of the bar +indicating the scores +} +\description{ +Plot benchmark summaries +}