Skip to content

Commit

Permalink
functions, remains to check validity
Browse files Browse the repository at this point in the history
  • Loading branch information
dajmcdon committed Sep 23, 2023
1 parent cea1599 commit d606741
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 42 deletions.
5 changes: 5 additions & 0 deletions R/compat-purrr.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ map_chr <- function(.x, .f, ...) {
.rlang_purrr_map_mold(.x, .f, character(1), ...)
}

map_vec <- function(.x, .f, ...) {
out <- map(.x, .f, ...)
vctrs::list_unchop(out)
}

map_dfr <- function(.x, .f, ..., .id = NULL) {
.f <- rlang::as_function(.f, env = rlang::global_env())
res <- map(.x, .f, ...)
Expand Down
144 changes: 106 additions & 38 deletions R/layer_cdc_flatline_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ layer_cdc_flatline_quantiles <- function(
nsims = 1e5,
by_key = "geo_value",
symmetrize = FALSE,
nonneg = TRUE,
id = rand_id("cdc_baseline_bands")) {

rlang::check_dots_empty()
Expand All @@ -15,7 +16,7 @@ layer_cdc_flatline_quantiles <- function(
arg_is_pos_int(nsims)
arg_is_scalar(nsims)
arg_is_chr_scalar(id)
arg_is_lgl_scalar(symmetrize)
arg_is_lgl_scalar(symmetrize, nonneg)
arg_is_chr(by_key, allow_null = TRUE, allow_na = TRUE, allow_empty = TRUE)

add_layer(
Expand All @@ -26,6 +27,7 @@ layer_cdc_flatline_quantiles <- function(
nsims = nsims,
by_key = by_key,
symmetrize = symmetrize,
nonneg = nonneg,
id = id
)
)
Expand All @@ -37,66 +39,132 @@ layer_cdc_flatline_quantiles_new <- function(
nsims,
by_key,
symmetrize,
nonneg,
id
) {
layer(
"cdc_flatline_quantiles",
aheads,
quantiles,
nsims,
by_key,
symmetrize,
id
aheads = aheads,
quantiles = quantiles,
nsims = nsims,
by_key = by_key,
symmetrize = symmetrize,
nonneg = nonneg,
id = id
)
}

#' @export
slather.layer_cdc_flatline_quantiles <-
function(object, components, workflow, new_data, ...) {
the_fit <- workflows::extract_fit_parsnip(workflow)
s <- ifelse(object$symmetrize, -1, NA)
if (!inherits(the_fit, "_flatline")) {
cli::cli_warn(
c("Predictions for this workflow were not produced by the {.cls flatline}",
"{.pkg parsnip} engine. Results may be unexpected. See {.fn epipredict::flatline}.")
)
}
p <- components$predictions
ek <- kill_time_value(epi_keys_mold(components$mold))
r <- grab_residuals(the_fit, components)

## Handle any grouping requests
avail_grps <- character(0L)
if (length(object$by_key) > 0L) {
key_cols <- dplyr::bind_cols(
geo_value = components$mold$extras$roles$geo_value,
components$mold$extras$roles$key
)
common <- intersect(object$by_key, names(key_cols))
excess <- setdiff(object$by_key, names(key_cols))
if (length(excess) > 0L) {
rlang::warn(
"Requested residual grouping key(s) {excess} are unavailable ",
"in the original data. Grouping by the remainder: {common}."
)
cols_in_preds <- hardhat::check_column_names(p, object$by_key)
if (!cols_in_preds$ok) {
cli::cli_warn(c(
"Predicted values are missing key columns: {.var cols_in_preds$missing_names}.",
"Ignoring these."
))
}
if (length(common) > 0L) {
r <- r %>% dplyr::select(tidyselect::any_of(c(common, ".resid")))
common_in_r <- common[common %in% names(r)]
if (length(common_in_r) != length(common)) {
rlang::warn(
"Some grouping keys are not in data.frame returned by the",
"`residuals()` method. Groupings may not be correct."
)
if (inherits(the_fit, "_flatline")) {
cols_in_resids <- hardhat::check_column_names(r, object$by_key)
if (!cols_in_resids$ok) {
cli::cli_warn(c(
"Existing residuals are missing key columns: {.var cols_in_resids$missing_names}.",
"Ignoring these."
))
}
# use only the keys that are in the predictions and requested.
avail_grps <- intersect(ek, setdiff(
object$by_key,
c(cols_in_preds$missing_names, cols_in_resids$missing_names)
))
} else { # not flatline, but we'll try
key_cols <- dplyr::bind_cols(
geo_value = components$mold$extras$roles$geo_value,
components$mold$extras$roles$key
)
cols_in_resids <- hardhat::check_column_names(key_cols, object$by_key)
if (!cols_in_resids$ok) {
cli::cli_warn(c(
"Requested residuals are missing key columns: {.var cols_in_resids$missing_names}.",
"Ignoring these."
))
}
r <- dplyr::bind_cols(key_cols, r) %>%
dplyr::group_by(!!!rlang::syms(common))
avail_grps <- intersect(names(key_cols), setdiff(
object$by_key,
c(cols_in_preds$missing_names, cols_in_resids$missing_names)
))
r <- dplyr::bind_cols(key_cols, r)
}
}
r <- r %>%
dplyr::select(tidyselect::all_of(c(avail_grps, ".resid"))) %>%
dplyr::group_by(!!!rlang::syms(avail_grps)) %>%
dplyr::summarise(.resid = list(.resid), .groups = "drop")

res <- dplyr::left_join(p, r, by = avail_grps) %>%
dplyr::rowwise() %>%
dplyr::mutate(
.pred_distn_all = propogate_samples(
.resid, .pred, object$quantiles,
object$aheads, object$nsim, object$symmetrize, object$nonneg
)
) %>%
dplyr::select(tidyselect::all_of(c(avail_grps, ".pred_distn_all")))





# always return components
# res <- check_pname(res, components$predictions, object)
components$predictions <- dplyr::left_join(
components$predictions,
res,
by = avail_grps
)
components
}

propogate_samples <- function(x, p, horizon, nsim, symmetrize) {
samp <- quantile(x, probs = c(0, seq_len(nsim)) / nsim)
propogate_samples <- function(
r, p, quantiles, aheads, nsim, symmetrize, nonneg) {
max_ahead <- max(aheads)
samp <- quantile(r, probs = c(0, seq_len(nsim - 1)) / (nsim - 1), na.rm = TRUE)
res <- list()

for (iter in seq(horizon)) {}
# p should be all the same
p <- max(p, na.rm = TRUE)

raw <- samp + p
if (nonneg) raw <- pmax(0, raw)
res[[1]] <- raw
if (max_ahead > 1L) {
for (iter in 2:max_ahead) {
samp <- shuffle(samp)
raw <- raw + samp
if (symmetrize) symmetric <- raw - (median(raw) + p)
else symmetric <- raw
if (nonneg) symmetric <- pmax(0, symmetric)
res[[iter]] <- symmetric
}
}
res <- res[aheads]
list(tibble::tibble(
aheads = aheads,
.pred_distn = map_vec(
res, ~ dist_quantiles(quantile(.x, quantiles), tau = quantiles)
)
))
}

shuffle <- function(x) {
stopifnot(is.vector(x))
sample(x, length(x), replace = FALSE)
}
8 changes: 4 additions & 4 deletions R/layer_residual_quantiles.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ grab_residuals <- function(the_fit, components) {
if (".resid" %in% names(r)) { # success
return(r)
} else { # failure
rlang::warn(c(
"The `residuals()` method for objects of class {cl} results in",
cli::cli_warn(c(
"The `residuals()` method for {.cls cl} objects results in",
"a data frame without a column named `.resid`.",
i = "Residual quantiles will be calculated directly from the",
i = "difference between predictions and observations.",
Expand All @@ -152,8 +152,8 @@ grab_residuals <- function(the_fit, components) {
} else if (is.vector(drop(r))) { # also success
return(tibble(.resid = drop(r)))
} else { # failure
rlang::warn(c(
"The `residuals()` method for objects of class {cl} results in an",
cli::cli_warn(c(
"The `residuals()` method for {.cls cl} objects results in an",
"object that is neither a data frame with a column named `.resid`,",
"nor something coercible to a vector.",
i = "Residual quantiles will be calculated directly from the",
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-propogate_samples.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
test_that("propogate_samples", {
r <- -30:50
p <- 40
quantiles <- 1:9 / 10
aheads <- c(2, 4, 7)
nsim <- 100

})
5 changes: 5 additions & 0 deletions tests/testthat/test-shuffle.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
test_that("shuffle works", {
expect_error(shuffle(matrix(NA, 2, 2)))
expect_length(shuffle(1:10), 10L)
expect_identical(sort(shuffle(1:10)), 1:10)
})

0 comments on commit d606741

Please sign in to comment.