Skip to content

Commit

Permalink
Made requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rachlobay committed Oct 24, 2023
1 parent d3df2c9 commit f7b38a5
Show file tree
Hide file tree
Showing 20 changed files with 418 additions and 128 deletions.
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ S3method(snap,default)
S3method(snap,dist_default)
S3method(snap,dist_quantiles)
S3method(snap,distribution)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
S3method(update_model,epi_workflow)
S3method(vec_ptype_abbr,dist_quantiles)
Expand Down
88 changes: 60 additions & 28 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ epi_recipe.default <- function(x, ...) {
#'
#' @export
#' @examples
#' library(epiprocess)
#' library(dplyr)
#' library(recipes)
#'
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value > "2021-08-01") %>%
#' dplyr::arrange(geo_value, time_value)
Expand Down Expand Up @@ -243,13 +239,13 @@ is_epi_recipe <- function(x) {
#' [workflows::add_recipe()] but sets a different
#' default blueprint to automatically handle [epiprocess::epi_df] data.
#'
#' @param x A `workflow` `or `epi_workflow`
#' @param x A `workflow` or `epi_workflow`
#'
#' @param recipe A recipe created using [recipes::recipe()].
#' @param recipe An epi recipe or recipe
#'
#' @param ... Not used
#'
#' @param blueprint A hardhat blueprint used for fine tuning the preprocessing.
#' @param blueprint A hardhat blueprint used for fine tuning the preprocessing
#'
#' [default_epi_recipe_blueprint()] is used.
#'
Expand All @@ -261,10 +257,6 @@ is_epi_recipe <- function(x) {
#'
#' @export
#' @examples
#' library(epiprocess)
#' library(dplyr)
#' library(recipes)
#'
#' jhu <- case_death_rate_subset %>%
#' filter(time_value > "2021-08-01") %>%
#' dplyr::arrange(geo_value, time_value)
Expand All @@ -284,7 +276,7 @@ is_epi_recipe <- function(x) {
#' r2 <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7)

#'
#' workflow <- update_epi_recipe(workflow, r2)
#'
#' workflow <- remove_epi_recipe(workflow)
Expand Down Expand Up @@ -318,9 +310,10 @@ remove_epi_recipe <- function(x) {

#' @rdname add_epi_recipe
#' @export
update_epi_recipe <- function(x, recipe, ..., blueprint = NULL) {
update_epi_recipe <- function(x, recipe, ..., blueprint = default_epi_recipe_blueprint()) {
rlang::check_dots_empty()
x <- remove_epi_recipe(x)
add_epi_recipe(x, recipe, blueprint = blueprint)
add_epi_recipe(x, recipe, ..., blueprint = blueprint)
}

#' Adjust a step in an `epi_workflow` or `epi_recipe`
Expand All @@ -330,13 +323,15 @@ update_epi_recipe <- function(x, recipe, ..., blueprint = NULL) {
#'
#'
#' @details This function can either adjust a step in a `epi_recipe` object
#' or a step from a `epi_recipe` object in an `epi_workflow`. In any case, the
#' argument name and update value must be inputted as `...`.
#' See the examples below for brief illustrations of both types of updates.
#' or a step from a `epi_recipe` object in an `epi_workflow`. The step to be
#' adjusted is indicated by either the step number or name (if a name is used,
#' it must be unique). In either case, the argument name and update value
#' must be inputted as `...`. See the examples below for brief
#' illustrations of the different types of updates.
#'
#' @param x A `epi_workflow` or `epi_recipe` object
#'
#' @param step_num the number of the step to adjust
#' @param which_step the number or name of the step to adjust
#'
#' @param ... Used to input a parameter adjustment
#'
Expand All @@ -360,35 +355,72 @@ update_epi_recipe <- function(x, recipe, ..., blueprint = NULL) {
#'
#' # Adjust `step_epi_ahead` to have an ahead value of 14
#' # in the `epi_workflow`
#' wf2 = wf %>% adjust_epi_recipe(step_num = 2, ahead = 14)
#' # Option 1. Using the step number:
#' wf2 <- wf %>% adjust_epi_recipe(which_step = 2, ahead = 14)
#' workflows::extract_preprocessor(wf2)
#' # Option 2. Using the step name:
#' wf3 <- wf %>% adjust_epi_recipe(which_step = "step_epi_ahead", ahead = 14)
#' workflows::extract_preprocessor(wf3)
#'
#' # Adjust `step_epi_ahead` to have an ahead value of 14
#' # in the `epi_recipe`
#' r2 = r %>% adjust_epi_recipe(step_num = 2, ahead = 14)
#' # Option 1. Using the step number
#' r2 <- r %>% adjust_epi_recipe(which_step = 2, ahead = 14)
#' r2
#' # Option 2. Using the step name
#' r3 <- r %>% adjust_epi_recipe(which_step = "step_epi_ahead", ahead = 14)
#' r3
#'
adjust_epi_recipe <- function(x, step_num, ..., blueprint = default_epi_recipe_blueprint()) {
adjust_epi_recipe <- function(x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
UseMethod("adjust_epi_recipe")
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_workflow <- function(
x, step_num, ..., blueprint = default_epi_recipe_blueprint()) {

recipe <- workflows::extract_preprocessor(x)
recipe$steps[[step_num]] <- update(recipe$steps[[step_num]], ...)
x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
recipe <- adjust_epi_recipe(workflows::extract_preprocessor(x), which_step, ...)

update_epi_recipe(x, recipe, blueprint = blueprint)
}

#' @rdname adjust_epi_recipe
#' @export
adjust_epi_recipe.epi_recipe <- function(
x, step_num, ..., blueprint = default_epi_recipe_blueprint()) {

x$steps[[step_num]] <- update(x$steps[[step_num]], ...)
x, which_step, ..., blueprint = default_epi_recipe_blueprint()) {
if (!(is.numeric(which_step) || is.character(which_step))) {
rlang::abort(
paste0(
"The step name (`which_step`) must be a number or a character."
)
)
} else if (is.numeric(which_step)) {
x$steps[[which_step]] <- update(x$steps[[which_step]], ...)
} else {
step_names <- map_chr(x$steps, ~ attr(.x, "class")[1])

if (!(which_step %in% step_names)) {
rlang::abort(
paste0(
"The step name (`which_step`) is not in the `epi_recipe` step names: ",
paste0(step_names, collapse = ", "),
"."
)
)
}
which_step_idx <- which(step_names == which_step)
if (length(which_step_idx) == 1) {
x$steps[[which_step_idx]] <- update(x$steps[[which_step_idx]], ...)
} else {
rlang::abort(
paste0(
"The step name (`which_step`) is not unique. Matches steps: ",
paste0(which_step_idx, collapse = ", "),
"."
)
)
}
}
x
}

Expand Down
25 changes: 15 additions & 10 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,10 @@ is_epi_workflow <- function(x) {
#' @export
#' @examples
#' jhu <- case_death_rate_subset %>%
#' dplyr::filter(time_value > "2021-11-01",
#' geo_value %in% c("ak", "ca", "ny"))
#' dplyr::filter(
#' time_value > "2021-11-01",
#' geo_value %in% c("ak", "ca", "ny")
#' )
#'
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
Expand All @@ -110,23 +112,26 @@ is_epi_workflow <- function(x) {
#' wf <- remove_model(wf)
#' wf
#' @export
add_model <- function(x, spec, ..., formula = NULL)
UseMethod('add_model')
add_model <- function(x, spec, ..., formula = NULL) {
UseMethod("add_model")
}

#' @rdname add_model
#' @export
remove_model <- function(x)
UseMethod('remove_model')
remove_model <- function(x) {
UseMethod("remove_model")
}

#' @rdname add_model
#' @export
update_model <- function(x, spec, ..., formula = NULL)
UseMethod('update_model')
update_model <- function(x, spec, ..., formula = NULL) {
UseMethod("update_model")
}

#' @rdname add_model
#' @export
add_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
workflows::add_model(x, spec, formula = formula)
workflows::add_model(x, spec, ..., formula = formula)
}

#' @rdname add_model
Expand All @@ -151,7 +156,7 @@ remove_model.epi_workflow <- function(x) {
update_model.epi_workflow <- function(x, spec, ..., formula = NULL) {
rlang::check_dots_empty()
x <- remove_model(x)
workflows::add_model(x, spec, formula = formula)
workflows::add_model(x, spec, ..., formula = formula)
}


Expand Down
87 changes: 65 additions & 22 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
#' filter(time_value > "2021-11-01", geo_value %in% c("ak", "ca", "ny"))
#' r <- epi_recipe(jhu) %>%
#' step_epi_lag(death_rate, lag = c(0, 7, 14)) %>%
#' step_epi_ahead(death_rate, ahead = 7) %>%
#' step_epi_naomit()
#' step_epi_ahead(death_rate, ahead = 7)
#'
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
#' latest <- jhu %>%
Expand Down Expand Up @@ -41,7 +40,7 @@
add_frosting <- function(x, frosting, ...) {
rlang::check_dots_empty()
action <- workflows:::new_action_post(frosting = frosting)
epi_add_action(x, action, "frosting")
epi_add_action(x, action, "frosting", ...)
}


Expand Down Expand Up @@ -96,6 +95,7 @@ validate_has_postprocessor <- function(x, ..., call = caller_env()) {
#' @rdname add_frosting
#' @export
update_frosting <- function(x, frosting, ...) {
rlang::check_dots_empty()
x <- remove_frosting(x)
add_frosting(x, frosting)
}
Expand All @@ -108,13 +108,15 @@ update_frosting <- function(x, frosting, ...) {
#'
#'
#' @details This function can either adjust a layer in a `frosting` object
#' or a layer from a `frosting` object in an `epi_workflow`. In any case, the
#' argument name and update value must be inputted as `...`.
#' See the examples below for brief illustrations of both types of updates.
#' or a layer from a `frosting` object in an `epi_workflow`. The layer to be
#' adjusted is indicated by either the layer number or name (if a name is used,
#' it must be unique). In either case, the argument name and update value
#' must be inputted as `...`. See the examples below for brief
#' illustrations of the different types of updates.
#'
#' @param x An `epi_workflow` or `frosting` object
#'
#' @param layer_num the number of the layer to adjust
#' @param which_layer the number or name of the layer to adjust
#'
#' @param ... Used to input a parameter adjustment
#'
Expand All @@ -133,44 +135,85 @@ update_frosting <- function(x, frosting, ...) {
#' wf <- epi_workflow(r, parsnip::linear_reg()) %>% fit(jhu)
#'
#' # in the frosting from the workflow
#' f1 <- frosting() %>% layer_predict() %>% layer_threshold(.pred)
#' f1 <- frosting() %>%
#' layer_predict() %>%
#' layer_threshold(.pred)
#'
#' wf2 = wf %>% add_frosting(f1)
#' wf2 <- wf %>% add_frosting(f1)
#'
#' # Adjust `layer_threshold` to have an upper bound of 1
#' # in the `epi_workflow`
#' wf2 = wf2 %>% adjust_frosting(layer_num = 2, upper = 1)
#' # Option 1. Using the layer number:
#' wf2 <- wf2 %>% adjust_frosting(which_layer = 2, upper = 1)
#' extract_frosting(wf2)
#' # Option 2. Using the layer name:
#' wf3 <- wf2 %>% adjust_frosting(which_layer = "layer_threshold", upper = 1)
#' extract_frosting(wf3)
#'
#' # Adjust `layer_threshold` to have an upper bound of 1
#' # Adjust `layer_threshold` to have an upper bound of 5
#' # in the `frosting` object
#' f2 = f1 %>% adjust_frosting(layer_num = 2, upper = 5)
#' extract_frosting(wf2)
#'
adjust_frosting <- function(x, layer_num, ...) {
#' # Option 1. Using the layer number:
#' f2 <- f1 %>% adjust_frosting(which_layer = 2, upper = 5)
#' f2
#' # Option 2. Using the layer name
#' f3 <- f1 %>% adjust_frosting(which_layer = "layer_threshold", upper = 5)
#' f3
#'
adjust_frosting <- function(x, which_layer, ...) {
UseMethod("adjust_frosting")
}

#' @rdname adjust_frosting
#' @export
adjust_frosting.epi_workflow <- function(
x, layer_num, ...) {

frosting <- extract_frosting(x)
frosting$layers[[layer_num]] <- update(frosting$layers[[layer_num]], ...)
x, which_layer, ...) {
frosting <- adjust_frosting(extract_frosting(x), which_layer, ...)

update_frosting(x, frosting)
}

#' @rdname adjust_frosting
#' @export
adjust_frosting.frosting <- function(
x, layer_num, ...) {

x$layers[[layer_num]] <- update(x$layers[[layer_num]], ...)
x, which_layer, ...) {
if (!(is.numeric(which_layer) || is.character(which_layer))) {
rlang::abort(
paste0(
"The layer name (`which_layer`) must be a number or a character."
)
)
} else if (is.numeric(which_layer)) {
x$layers[[which_layer]] <- update(x$layers[[which_layer]], ...)
} else {
layer_names <- map_chr(x$layers, ~ attr(.x, "class")[1])

if (!(which_layer %in% layer_names)) {
rlang::abort(
paste0(
"The layer name (`which_layer`) is not in the `frosting` layer names: ",
paste0(layer_names, collapse = ", "),
"."
)
)
}
which_layer_idx <- which(layer_names == which_layer)
if (length(which_layer_idx) == 1) {
x$layers[[which_layer_idx]] <- update(x$layers[[which_layer_idx]], ...)
} else {
rlang::abort(
paste0(
"The layer name (`which_layer`) is not unique. Matches layers: ",
paste0(which_layer_idx, collapse = ", "),
"."
)
)
}
}
x
}



#' @importFrom rlang caller_env
add_postprocessor <- function(x, postprocessor, ..., call = caller_env()) {
rlang::check_dots_empty()
Expand Down
Loading

0 comments on commit f7b38a5

Please sign in to comment.