Skip to content

Commit

Permalink
Use forecasttools::categorize_prism in `plot_category_pointintervals.R` (#257)
Browse files Browse the repository at this point in the history

* Update plot_category_pointintervals to use forecasttools::categorize_prism

* Remove superfluous arg

* Remove old category read-in from main

* Further cleanup
  • Loading branch information
dylanhmorris authored Dec 18, 2024
1 parent 7e70d44 commit c7f6c1e
Showing 1 changed file with 11 additions and 99 deletions.
110 changes: 11 additions & 99 deletions pipelines/plot_category_pointintervals.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,64 +4,20 @@ library(dplyr)
library(argparser)


categorize_vec <- function(values, break_sets, label_sets) {
  # Categorize each value against its own set of cut points.
  #
  # @param values Numeric vector of values to categorize.
  # @param break_sets List of numeric break vectors, one per entry of
  #   `values` (length-1 lists are recycled by `pmap_vec()`).
  # @param label_sets List of character label vectors, one per entry of
  #   `values`; each must have one fewer element than its break vector.
  # @return A vector of ordered-factor categories, one per value.
  return(purrr::pmap_vec(
    list(
      x = values,
      breaks = break_sets,
      labels = label_sets,
      include.lowest = TRUE,
      # Spelled out in full: the original `order = TRUE` is not a formal
      # of cut() and only worked via partial matching to `ordered_result`.
      ordered_result = TRUE,
      right = TRUE
    ),
    cut
  ))
}



with_category_cutpoints <- function(df,
                                    disease,
                                    categories) {
  # Attach per-location category cutpoints for a single disease.
  #
  # @param df A data frame with a `location` column.
  # @param disease Disease name; injected as a `disease` column so the
  #   join keeps only this disease's rows of `categories`.
  # @param categories Cutpoint table keyed by `location` and `disease`.
  # @return `df` joined to the matching cutpoint rows.
  df |>
    mutate(disease = !!disease) |>
    inner_join(categories, by = c("location", "disease"))
}


to_categorized_iqr <- function(hub_table,
disease,
categories,
.keep = FALSE) {
disease) {
result <- hub_table |>
pivot_hubverse_quantiles_wider() |>
with_category_cutpoints(
disease = disease,
categories = categories
) |>
mutate(
category_point = categorize_vec(
.data$point,
.data$bin_breaks,
.data$bin_names
),
category_lower = categorize_vec(
.data$lower,
.data$bin_breaks,
.data$bin_names
),
category_upper = categorize_vec(
.data$upper,
.data$bin_breaks,
.data$bin_names
),
across(c("point", "lower", "upper"),
~ forecasttools::categorize_prism(
.x,
.data$location,
!!disease
),
.names = "category_{.col}"
)
)

if (!.keep) {
result <- result |> select(-c(bin_breaks, bin_names))
}

return(result)
}

Expand Down Expand Up @@ -118,51 +74,11 @@ main <- function(influenza_table_path,
categories_path,
output_path,
...) {
categories <- arrow::read_parquet(categories_path) |>
transmute(
disease,
location = state_abb,
prop_lower_bound = 0,
prop_low = perc_level_low / 100,
prop_moderate = perc_level_moderate / 100,
prop_high = perc_level_high / 100,
prop_very_high = perc_level_very_high / 100,
prop_upper_bound = 1,
very_low_name = "Very Low",
low_name = "Low",
moderate_name = "Moderate",
high_name = "High",
very_high_name = "Very High"
) |>
tidyr::nest(
bin_breaks = c(
prop_lower_bound,
prop_low,
prop_moderate,
prop_high,
prop_very_high,
prop_upper_bound
),
bin_names = c(
very_low_name,
low_name,
moderate_name,
high_name,
very_high_name
)
)

flu_dat <- readr::read_tsv(influenza_table_path) |>
to_categorized_iqr(
"Influenza",
categories
)
to_categorized_iqr("Influenza")

covid_dat <- readr::read_tsv(covid_table_path) |>
to_categorized_iqr(
"COVID-19",
categories
)
to_categorized_iqr("COVID-19")

plots <- list(
flu_plot_1wk = flu_dat |>
Expand Down Expand Up @@ -199,10 +115,6 @@ p <- arg_parser("Create a pointinterval plot of forecasts") |>
"covid_table_path",
help = "Path to a hubverse format forecast table for COVID-19."
) |>
add_argument(
"categories_path",
help = "Path to a parquet file containing PRISM category cutpoints."
) |>
add_argument(
"output_path",
help = "Path to save the output plots, as a single PDF"
Expand Down

0 comments on commit c7f6c1e

Please sign in to comment.