diff --git a/pipelines/plot_category_pointintervals.R b/pipelines/plot_category_pointintervals.R index 7c56680..90fdf2a 100644 --- a/pipelines/plot_category_pointintervals.R +++ b/pipelines/plot_category_pointintervals.R @@ -4,64 +4,20 @@ library(dplyr) library(argparser) -categorize_vec <- function(values, break_sets, label_sets) { - return(purrr::pmap_vec( - list( - x = values, - breaks = break_sets, - labels = label_sets, - include.lowest = TRUE, - order = TRUE, - right = TRUE - ), - cut - )) -} - - - -with_category_cutpoints <- function(df, - disease, - categories) { - with_cutpoints <- df |> - mutate(disease = !!disease) |> - inner_join(categories, by = c("location", "disease")) - return(with_cutpoints) -} - - to_categorized_iqr <- function(hub_table, - disease, - categories, - .keep = FALSE) { + disease) { result <- hub_table |> pivot_hubverse_quantiles_wider() |> - with_category_cutpoints( - disease = disease, - categories = categories - ) |> mutate( - category_point = categorize_vec( - .data$point, - .data$bin_breaks, - .data$bin_names - ), - category_lower = categorize_vec( - .data$lower, - .data$bin_breaks, - .data$bin_names - ), - category_upper = categorize_vec( - .data$upper, - .data$bin_breaks, - .data$bin_names - ), + across(c("point", "lower", "upper"), + ~ forecasttools::categorize_prism( + .x, + .data$location, + !!disease + ), + .names = "category_{.col}" + ) ) - - if (!.keep) { - result <- result |> select(-c(bin_breaks, bin_names)) - } - return(result) } @@ -118,51 +74,11 @@ main <- function(influenza_table_path, categories_path, output_path, ...) { - categories <- arrow::read_parquet(categories_path) |> - transmute( - disease, - location = state_abb, - prop_lower_bound = 0, - prop_low = perc_level_low / 100, - prop_moderate = perc_level_moderate / 100, - prop_high = perc_level_high / 100, - prop_very_high = perc_level_very_high / 100, - prop_upper_bound = 1, - very_low_name = "Very Low", - low_name = "Low", - moderate_name = "Moderate", - high_name = "High", - very_high_name = "Very High" - ) |> - tidyr::nest( - bin_breaks = c( - prop_lower_bound, - prop_low, - prop_moderate, - prop_high, - prop_very_high, - prop_upper_bound - ), - bin_names = c( - very_low_name, - low_name, - moderate_name, - high_name, - very_high_name - ) - ) - flu_dat <- readr::read_tsv(influenza_table_path) |> - to_categorized_iqr( - "Influenza", - categories - ) + to_categorized_iqr("Influenza") covid_dat <- readr::read_tsv(covid_table_path) |> - to_categorized_iqr( - "COVID-19", - categories - ) + to_categorized_iqr("COVID-19") plots <- list( flu_plot_1wk = flu_dat |> @@ -199,10 +115,6 @@ p <- arg_parser("Create a pointinterval plot of forecasts") |> "covid_table_path", help = "Path to a hubverse format forecast table for COVID-19." ) |> - add_argument( - "categories_path", - help = "Path to a parquet file containing PRISM category cutpoints." - ) |> add_argument( "output_path", help = "Path to save the output plots, as a single PDF"