Merge branch 'main' into 184-add-outputs-to-sim-data
dylanhmorris authored Oct 23, 2024
2 parents 8ade72d + e2e4fa1 commit 0ce9260
Showing 6 changed files with 98 additions and 23 deletions.
4 changes: 3 additions & 1 deletion NEWS.md
@@ -1,6 +1,8 @@
# wwinference dev

- Add wastewater data into the forecast period to output in `generate_simulated_data()` function and as package data. Also adds subpopulation-level
hospital admissions to output of function and package data. ([#184](https://github.com/CDCgov/ww-inference-model/issues/184))
- Modify `get_plot_forecasted_counts()` so that it does not require an evaluation dataset ([#218](https://github.com/CDCgov/ww-inference-model/pull/218))

# wwinference 0.1.0

@@ -14,4 +16,4 @@ As it's written, the package is intended to allow users to do the following:
- Validate input data with informative error messaging ([#37](https://github.com/CDCgov/ww-inference-model/issues/37), [#54](https://github.com/CDCgov/ww-inference-model/issues/54))
- Provide a wrapper function to generate forward simulated data with user-specified variables. It calls a number of functions to perform specific model components ([#27](https://github.com/CDCgov/ww-inference-model/issues/27))
- Contains S3 class methods applied to the output of the main model wrapper function, the `wwinference_fit` class object ([#58](https://github.com/CDCgov/ww-inference-model/issues/58)).
-- Wastewater concentration data is expected to be in log scale ([#122](https://onetakeda.box.com/s/pju273g5khx3y3cwoae2zwv3e7vu03x3)).
+- Wastewater concentration data is expected to be in log scale ([#122](https://github.com/CDCgov/ww-inference-model/pull/122)).
26 changes: 16 additions & 10 deletions R/figures.R
@@ -7,9 +7,11 @@
#' @param count_data_eval A dataframe containing the count data we will
#' evaluate the forecasts against. Must contain the columns `date` and
#' a column indicating the count data to evaluate against, with the name
-#' of that column specified as the `count_data_eval_col_name`
+#' of that column specified as the `count_data_eval_col_name`. Default is
+#' NULL, which will result in no evaluation data being plotted.
#' @param count_data_eval_col_name string indicating the name of the count
-#' data to evaluate against the forecasted count data
+#' data to evaluate against the forecasted count data. Default is NULL,
+#' corresponding to no evaluation data being plotted.
#' @param forecast_date A string indicating the date we made the forecast, for
#' plotting, in ISO8601 format YYYY-MM-DD
#' @param count_type A string indicating what data the counts refer to,
@@ -25,9 +27,9 @@
#' @export
#'
get_plot_forecasted_counts <- function(draws,
-count_data_eval,
-count_data_eval_col_name,
forecast_date,
+count_data_eval = NULL,
+count_data_eval_col_name = NULL,
count_type = "hospital admissions",
n_draws_to_plot = 100) {
n_draws_available <- max(draws$draw)
@@ -55,11 +57,6 @@ get_plot_forecasted_counts <- function(draws,
aes(x = .data$date, y = .data$pred_value, group = .data$draw),
color = "red4", alpha = 0.1, linewidth = 0.2
) +
-geom_point(
-data = count_data_eval,
-aes(x = .data$date, y = .data[[count_data_eval_col_name]]),
-shape = 21, color = "black", fill = "white"
-) +
geom_point(aes(x = .data$date, y = .data$observed_value)) +
geom_vline(
xintercept = lubridate::ymd(forecast_date),
@@ -85,6 +82,15 @@ get_plot_forecasted_counts <- function(draws,
vjust = 0.5, hjust = 0.5
)
)
+
+if (!is.null(count_data_eval)) {
+p <- p +
+geom_point(
+data = count_data_eval,
+aes(x = .data$date, y = .data[[count_data_eval_col_name]]),
+shape = 21, color = "black", fill = "white"
+)
+}
return(p)
}
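
The change above moves the evaluation points out of the main `ggplot` chain and adds them only when `count_data_eval` is supplied. A minimal, self-contained sketch of that optional-layer pattern follows. The helper `plot_with_optional_eval()`, its argument names, and the toy data are hypothetical and only illustrate the idea; the early `stopifnot()` check is an added assumption, not something the commit does.

```r
library(ggplot2)

# Hypothetical helper illustrating the optional-layer pattern;
# this is not the package's actual function.
plot_with_optional_eval <- function(forecast_df,
                                    eval_df = NULL,
                                    eval_col = NULL) {
  p <- ggplot(forecast_df) +
    geom_line(aes(x = .data$date, y = .data$pred_value))

  # Add the evaluation points only when evaluation data was supplied
  if (!is.null(eval_df)) {
    # Assumption: fail early if the column name is missing or misspelled
    stopifnot(!is.null(eval_col), eval_col %in% names(eval_df))
    p <- p +
      geom_point(
        data = eval_df,
        aes(x = .data$date, y = .data[[eval_col]]),
        shape = 21, color = "black", fill = "white"
      )
  }
  p
}

# Toy data, made up for illustration
toy_forecast <- data.frame(
  date = as.Date("2024-01-01") + 0:4,
  pred_value = c(10, 12, 11, 13, 14)
)
toy_eval <- data.frame(
  date = as.Date("2024-01-01") + 0:4,
  admits = c(9, 13, 10, 12, 15)
)

p_without <- plot_with_optional_eval(toy_forecast)
p_with <- plot_with_optional_eval(toy_forecast, toy_eval, "admits")
```

Because the layer is appended after the base plot is built, the function keeps a single return path, and callers who have no retrospective data can leave both arguments at their `NULL` defaults.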

@@ -132,7 +138,7 @@ get_plot_ww_conc <- function(draws,
aes(x = .data$date, y = .data$observed_value),
color = "blue", show.legend = FALSE, size = 0.5
) +
-facet_wrap(~lab_site_name, scales = "free") +
+facet_wrap(~lab_site_name, scales = "free_y") +
geom_vline(
xintercept = lubridate::ymd(forecast_date),
linetype = "dashed"
3 changes: 2 additions & 1 deletion R/get_stan_data.R
@@ -431,7 +431,8 @@ get_stan_data <- function(input_count_data,
arg_max_date = "forecast date"
)

-# Validate both datasets if both are used----------------------------------
+# if both datasets are used, validate that they are
+# compatible and consistent with each other
if (include_ww == 1) {
validate_data_jointly(
input_count_data = input_count_data,
16 changes: 9 additions & 7 deletions man/get_plot_forecasted_counts.Rd

Some generated files are not rendered by default.

53 changes: 53 additions & 0 deletions tests/testthat/test_plots.R
@@ -0,0 +1,53 @@
t_length <- 127
forecast_date <- "2024-01-01"
data <- tibble::tibble(
date = seq(
from = lubridate::ymd("2023-10-01"),
to = lubridate::ymd("2023-10-01") + lubridate::days(t_length - 1),
by = "days"
),
observed_value = sample(10:25, t_length, replace = TRUE)
)

draws <- tibble::tibble()
for (i in 1:100) {
draws_i <- data |>
dplyr::mutate(
pred_value = observed_value +
runif(t_length, min = -10, max = 10),
draw = i
)
draws <- dplyr::bind_rows(draws, draws_i)
}

test_draws <- draws |>
dplyr::mutate(
observed_value = ifelse(date < forecast_date, observed_value, NA)
)

test_eval_data <- data |>
dplyr::rename("daily_hosp_admits_eval" = observed_value)




test_that("Test there is no error with eval data", {
expect_no_error(
get_plot_forecasted_counts(
draws = test_draws,
forecast_date = forecast_date,
count_data_eval = test_eval_data,
count_data_eval_col_name = "daily_hosp_admits_eval"
)
)
})


test_that("Test there is no error without eval data", {
expect_no_error(
get_plot_forecasted_counts(
draws = test_draws,
forecast_date = forecast_date
)
)
})
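
The fixture in this test file draws random values without a fixed seed and grows `draws` one iteration at a time. A hedged sketch of an alternative is shown below: it sets a seed (the value `123` is arbitrary) and builds the same long-format table in one step using base R only. The column names mirror the fixture above; everything else is an assumption made for illustration, not part of the commit.

```r
set.seed(123) # assumption: any fixed seed makes the fixture reproducible

t_length <- 127
n_draws <- 100
forecast_date <- as.Date("2024-01-01")

dates <- seq(as.Date("2023-10-01"), by = "day", length.out = t_length)
observed <- sample(10:25, t_length, replace = TRUE)

# One row per (draw, date) pair, built without a loop
draws <- data.frame(
  draw = rep(seq_len(n_draws), each = t_length),
  date = rep(dates, times = n_draws),
  observed_value = rep(observed, times = n_draws)
)

# Add uniform noise around the observed values to mimic predictive draws
draws$pred_value <- draws$observed_value +
  runif(nrow(draws), min = -10, max = 10)

# Mask observations on or after the forecast date, as the fixture does
draws$observed_value[draws$date >= forecast_date] <- NA
```

A fixed seed keeps any failure repeatable across runs, which helps if the tests are later extended beyond `expect_no_error()`.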
19 changes: 15 additions & 4 deletions vignettes/wwinference.Rmd
@@ -465,13 +465,14 @@ will not include outliers that were flagged for exclusion. Data points
that are below the LOD will be plotted in blue.

```{r generating-figures, out.width='100%'}
-plot_hosp <- get_plot_forecasted_counts(
+plot_hosp_with_eval <- get_plot_forecasted_counts(
draws = draws$predicted_counts,
+forecast_date = forecast_date,
count_data_eval = hosp_data_eval,
-count_data_eval_col_name = "daily_hosp_admits_for_eval",
-forecast_date = forecast_date
+count_data_eval_col_name = "daily_hosp_admits_for_eval"
)
-plot_hosp
+plot_hosp_with_eval
plot_ww <- get_plot_ww_conc(draws$predicted_ww, forecast_date)
plot_ww
@@ -483,6 +484,16 @@ plot_subpop_rt <- get_plot_subpop_rt(draws$subpop_rt, forecast_date)
plot_subpop_rt
```

+To plot the forecasts without the retrospectively observed hospital admissions,
+simply don't pass them to the plotting function.
+```{r plot-only-count-forecasts, out.width='100%'}
+plot_hosp <- get_plot_forecasted_counts(
+draws = draws$predicted_counts,
+forecast_date = forecast_date
+)
+plot_hosp
+```

The previous three are equivalent to calling the `plot` method of `wwinference_fit_draws` using the `what` argument:

```{r, out.width='100%'}