Skip to content

Commit

Permalink
Merge pull request #280 from cmu-delphi/277-recipe-printing-bug
Browse files Browse the repository at this point in the history
277 recipe printing bug
  • Loading branch information
dajmcdon authored Dec 23, 2023
2 parents 6be79f1 + d22efc5 commit 2d05865
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 208 deletions.
62 changes: 33 additions & 29 deletions R/canned-epipred.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,39 +58,43 @@ print.alist <- function(x, ...) {

#' @export
print.canned_epipred <- function(x, name, ...) {
cat("\n")
bullet <- "\u2022"
header <- glue::glue("A basic forecaster of type {name}")
header <- cli::rule(header, line = 2)
cat_line(header)
cat("\n")

date_created <- glue::glue(
"This forecaster was fit on {format(x$metadata$forecast_created)}"
d <- cli::cli_div(theme = list(rule = list("line-type" = "double")))
cli::cli_rule("A basic forecaster of type {name}")
cli::cli_end(d)
cli::cli_text("")
cli::cli_text(
"This forecaster was fit on {.field {format(x$metadata$forecast_created)}}."
)
cat_line(date_created)
cat("\n")

cat_line("Training data was an `epi_df` with")
cat_line(glue::glue("\u2022 Geography: {x$metadata$training$geo_type},"))
cat_line(glue::glue("{bullet} Time type: {x$metadata$training$time_type},"))
cat_line(glue::glue("{bullet} Using data up-to-date as of: {format(x$metadata$training$as_of)}."))
cli::cli_text("")
cli::cli_text("Training data was an {.cls epi_df} with:")
cli::cli_ul(c(
"Geography: {.field {x$metadata$training$geo_type}},",
"Time type: {.field {x$metadata$training$time_type}},",
"Using data up-to-date as of: {.field {format(x$metadata$training$as_of)}}."
))
cli::cli_text("")

cat("\n")
header <- cli::rule("Predictions")
cat_line(header)
cat("\n")
cli::cli_rule("Predictions")
cli::cli_text("")

n_geos <- dplyr::n_distinct(x$predictions$geo_value)
fds <- unique(x$predictions$forecast_date)
tds <- unique(x$predictions$target_date)

cat_line(
glue::glue("A total of {nrow(x$predictions)} predictions are available for")
fds <- cli::cli_vec(
unique(x$predictions$forecast_date),
list("vec-trunc" = 5)
)
tds <- cli::cli_vec(
unique(x$predictions$target_date),
list("vec-trunc" = 5)
)
cat_line(glue::glue("{bullet} {n_geos} unique geographic regions,"))
cat_line(glue::glue("{bullet} At forecast dates: {fds},"))
cat_line(glue::glue("{bullet} For target dates: {tds}."))

cat("\n")
cli::cli_text(c(
"A total of {.val {nrow(x$predictions)}} prediction{?s}",
" {?is/are} available for"
))
cli::cli_ul(c(
"{.val {n_geos}} unique geographic region{?s},",
"At forecast date{?s}: {.val {fds}},",
"For target date{?s}: {.val {tds}}."
))
cli::cli_text("")
}
47 changes: 19 additions & 28 deletions R/epi_recipe.R
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,15 @@ adjust_epi_recipe.epi_recipe <- function(
prep.epi_recipe <- function(
x, training = NULL, fresh = FALSE, verbose = FALSE,
retain = TRUE, log_changes = FALSE, strings_as_factors = TRUE, ...) {
if (is.null(training)) {
cli::cli_warn(c(
"!" = "No training data was supplied to {.fn prep}.",
"!" = "Unlike a {.cls recipe}, an {.cls epi_recipe} does not ",
"!" = "store the full template data in the object.",
"!" = "Please supply the training data to the {.fn prep} function,",
"!" = "to avoid addtional warning messages."
))
}
training <- recipes:::check_training_set(training, x, fresh)
training <- epi_check_training_set(training, x)
training <- dplyr::relocate(training, tidyselect::all_of(epi_keys(training)))
Expand Down Expand Up @@ -598,12 +607,12 @@ print.epi_recipe <- function(x, form_width = 30, ...) {
cli::cli_h3("Operations")
}

i <- 1
for (step in x$steps) {
cat(paste0(i, ". "))
print(step, form_width = form_width)
i <- i + 1
}
fmt <- cli::cli_fmt({
for (step in x$steps) {
print(step, form_width = form_width)
}
})
cli::cli_ol(fmt)
cli::cli_end()

invisible(x)
Expand All @@ -614,20 +623,12 @@ print_preprocessor_recipe <- function(x, ...) {
recipe <- workflows::extract_preprocessor(x)
steps <- recipe$steps
n_steps <- length(steps)
if (n_steps == 1L) {
step <- "Step"
} else {
step <- "Steps"
}
n_steps_msg <- glue::glue("{n_steps} Recipe {step}")
cat_line(n_steps_msg)
cli::cli_text("{n_steps} Recipe step{?s}.")

if (n_steps == 0L) {
return(invisible(x))
}

cat_line("")

step_names <- map_chr(steps, workflows:::pull_step_name)

if (n_steps <= 10L) {
Expand All @@ -638,17 +639,8 @@ print_preprocessor_recipe <- function(x, ...) {
extra_steps <- n_steps - 10L
step_names <- step_names[1:10]

if (extra_steps == 1L) {
step <- "step"
} else {
step <- "steps"
}

extra_dots <- "..."
extra_msg <- glue::glue("and {extra_steps} more {step}.")

cli::cli_ol(step_names)
cli::cli_bullets(c(extra_dots, extra_msg))
cli::cli_bullets("... and {extra_steps} more step{?s}.")
invisible(x)
}

Expand All @@ -664,9 +656,8 @@ print_preprocessor <- function(x) {
return(invisible(x))
}

cat_line("")
header <- cli::rule("Preprocessor")
cat_line(header)
cli::cli_rule("Preprocessor")
cli::cli_text("")

if (has_preprocessor_formula) {
workflows:::print_preprocessor_formula(x)
Expand Down
45 changes: 1 addition & 44 deletions R/epi_workflow.R
Original file line number Diff line number Diff line change
Expand Up @@ -323,50 +323,7 @@ print.epi_workflow <- function(x, ...) {
print_header(x)
print_preprocessor(x)
# workflows:::print_case_weights(x)
workflows:::print_model(x)
print_model(x)
print_postprocessor(x)
invisible(x)
}

print_header <- function(x) {
# same as in workflows but with a postprocessor
trained <- ifelse(workflows::is_trained_workflow(x), " [trained]", "")

header <- glue::glue("Epi Workflow{trained}")
header <- cli::rule(header, line = 2)
cat_line(header)

preprocessor_msg <- cli::style_italic("Preprocessor:")

if (workflows:::has_preprocessor_formula(x)) {
preprocessor <- "Formula"
} else if (workflows:::has_preprocessor_recipe(x)) {
preprocessor <- "Recipe"
} else if (workflows:::has_preprocessor_variables(x)) {
preprocessor <- "Variables"
} else {
preprocessor <- "None"
}

preprocessor_msg <- glue::glue("{preprocessor_msg} {preprocessor}")
cat_line(preprocessor_msg)

spec_msg <- cli::style_italic("Model:")

if (workflows:::has_spec(x)) {
spec <- class(workflows::extract_spec_parsnip(x))[[1]]
spec <- glue::glue("{spec}()")
} else {
spec <- "None"
}

spec_msg <- glue::glue("{spec_msg} {spec}")
cat_line(spec_msg)

postprocessor_msg <- cli::style_italic("Postprocessor:")
postprocessor <- ifelse(has_postprocessor_frosting(x), "Frosting", "None")
postprocessor_msg <- glue::glue("{postprocessor_msg} {postprocessor}")
cat_line(postprocessor_msg)

invisible(x)
}
61 changes: 7 additions & 54 deletions R/frosting.R
Original file line number Diff line number Diff line change
Expand Up @@ -415,60 +415,13 @@ print.frosting <- function(x, form_width = 30, ...) {
cli::cli_h1("Frosting")

if (!is.null(x$layers)) cli::cli_h3("Layers")
i <- 1
for (layer in x$layers) {
cat(paste0(i, ". "))
print(layer, form_width = form_width)
i <- i + 1
}
cli::cli_end()
invisible(x)
}

# Currently only used in the workflow printing
print_frosting <- function(x, ...) {
layers <- x$layers
n_layers <- length(layers)
layer <- ifelse(n_layers == 1L, "Layer", "Layers")
n_layers_msg <- glue::glue("{n_layers} Frosting {layer}")
cat_line(n_layers_msg)

if (n_layers == 0L) {
return(invisible(x))
}

cat_line("")

layer_names <- map_chr(layers, pull_layer_name)

if (n_layers <= 10L) {
cli::cli_ol(layer_names)
return(invisible(x))
}

extra_layers <- n_layers - 10L
layer_names <- layer_names[1:10]

layer <- ifelse(extra_layers == 1L, "layer", "layers")

extra_dots <- "..."
extra_msg <- glue::glue("and {extra_layers} more {layer}.")

cli::cli_ol(layer_names)
cli::cli_bullets(c(extra_dots, extra_msg))
invisible(x)
}

print_postprocessor <- function(x) {
if (!has_postprocessor_frosting(x)) {
return(invisible(x))
}

header <- cli::rule("Postprocessor")
cat_line(header)

frost <- extract_frosting(x)
print_frosting(frost)

fmt <- cli::cli_fmt({
for (layer in x$layers) {
print(layer, form_width = form_width)
}
})
cli::cli_ol(fmt)
cli::cli_end()
invisible(x)
}
Loading

0 comments on commit 2d05865

Please sign in to comment.