Skip to content

Commit

Permalink
Merge pull request #254 from cmu-delphi/v0.0.6-cleanup
Browse files Browse the repository at this point in the history
V0.0.6 cleanup
  • Loading branch information
dshemetov authored Oct 19, 2023
2 parents 8c72690 + 9206afa commit 64a820d
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 54 deletions.
4 changes: 2 additions & 2 deletions R/cdc_baseline_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,11 @@ cdc_baseline_forecaster <- function(
#' cdc_baseline_args_list(quantile_levels = c(.1, .3, .7, .9), n_training = 120)
cdc_baseline_args_list <- function(
data_frequency = "1 week",
aheads = 1:4,
aheads = 1:5,
n_training = Inf,
forecast_date = NULL,
quantile_levels = c(.01, .025, 1:19 / 20, .975, .99),
nsims = 1e3L,
nsims = 1e5L,
symmetrize = TRUE,
nonneg = TRUE,
quantile_by_key = "geo_value",
Expand Down
75 changes: 40 additions & 35 deletions R/flusight_hub_formatter.R
Original file line number Diff line number Diff line change
@@ -1,19 +1,28 @@
abbr_to_fips <- function(abbr) {
fi <- dplyr::left_join(
tibble::tibble(abbr = tolower(abbr)),
state_census,
by = "abbr"
) %>%
dplyr::mutate(fips = as.character(fips), fips = case_when(
fips == "0" ~ "US",
nchar(fips) < 2L ~ paste0("0", fips),
TRUE ~ fips
)) %>%
pull(.data$fips)
names(fi) <- NULL
fi
location_to_abbr <- function(location) {
dictionary <-
state_census %>%
dplyr::mutate(fips = sprintf("%02d", fips)) %>%
dplyr::transmute(
location = dplyr::case_match(fips, "00" ~ "US", .default = fips),
abbr
)
dictionary$abbr[match(location, dictionary$location)]
}

abbr_to_location <- function(abbr) {
dictionary <-
state_census %>%
dplyr::mutate(fips = sprintf("%02d", fips)) %>%
dplyr::transmute(
location = dplyr::case_match(fips, "00" ~ "US", .default = fips),
abbr
)
dictionary$location[match(abbr, dictionary$abbr)]
}




#' Format predictions for submission to FluSight forecast Hub
#'
#' This function converts predictions from any of the included forecasters into
Expand Down Expand Up @@ -47,22 +56,23 @@ abbr_to_fips <- function(abbr) {
#' @export
#'
#' @examples
#' library(dplyr)
#' weekly_deaths <- case_death_rate_subset %>%
#' select(geo_value, time_value, death_rate) %>%
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#' if (require(dplyr)) {
#' weekly_deaths <- case_death_rate_subset %>%
#' select(geo_value, time_value, death_rate) %>%
#' left_join(state_census %>% select(pop, abbr), by = c("geo_value" = "abbr")) %>%
#' mutate(deaths = pmax(death_rate / 1e5 * pop * 7, 0)) %>%
#' select(-pop, -death_rate) %>%
#' group_by(geo_value) %>%
#' epi_slide(~ sum(.$deaths), before = 6, new_col_name = "deaths") %>%
#' ungroup() %>%
#' filter(weekdays(time_value) == "Saturday")
#'
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' flusight_hub_formatter(cdc)
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile")
#' cdc <- cdc_baseline_forecaster(weekly_deaths, "deaths")
#' flusight_hub_formatter(cdc)
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths")
#' flusight_hub_formatter(cdc, target = paste(horizon, "wk inc covid deaths"))
#' flusight_hub_formatter(cdc, target = "wk inc covid deaths", output_type = "quantile")
#' }
flusight_hub_formatter <- function(
object, ...,
.fcast_period = c("daily", "weekly")) {
Expand Down Expand Up @@ -93,11 +103,6 @@ flusight_hub_formatter.data.frame <- function(
object <- object %>%
# combine the predictions and the distribution
dplyr::mutate(.pred_distn = nested_quantiles(.pred_distn)) %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred_distn = list(add_row(.pred_distn, values = .pred, quantile_levels = NA)),
.pred = NULL
) %>%
tidyr::unnest(.pred_distn) %>%
# now we create the correct column names
dplyr::rename(
Expand All @@ -106,7 +111,7 @@ flusight_hub_formatter.data.frame <- function(
reference_date = forecast_date
) %>%
# convert to fips codes, and add any constant cols passed in ...
dplyr::mutate(location = abbr_to_fips(tolower(geo_value)), geo_value = NULL)
dplyr::mutate(location = abbr_to_location(tolower(geo_value)), geo_value = NULL)

# create target_end_date / horizon, depending on what is available
pp <- ifelse(match.arg(.fcast_period) == "daily", 1L, 7L)
Expand Down
4 changes: 2 additions & 2 deletions man/cdc_baseline_args_list.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 16 additions & 15 deletions man/flusight_hub_formatter.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 64a820d

Please sign in to comment.