diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7ee90bd5..87d0ff6b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,99 +1,103 @@ # All available hooks: https://pre-commit.com/hooks.html # R specific hooks: https://github.com/lorenzwalthert/precommit repos: -# R -- repo: https://github.com/lorenzwalthert/precommit + # R + - repo: https://github.com/lorenzwalthert/precommit rev: v0.4.3 hooks: - - id: style-files - args: [--style_pkg=styler, --style_fun=tidyverse_style, - --cache-root=styler-perm] - - id: use-tidy-description - - id: lintr - - id: readme-rmd-rendered - - id: parsable-R - - id: no-browser-statement - - id: no-print-statement + - id: style-files + args: + [ + --style_pkg=styler, + --style_fun=tidyverse_style, + --cache-root=styler-perm, + ] + - id: use-tidy-description + - id: lintr + - id: readme-rmd-rendered + - id: parsable-R + - id: no-browser-statement + - id: no-print-statement exclude: '^tests/testthat/test-print\.R$' - - id: no-debug-statement - - id: deps-in-desc -- repo: https://github.com/pre-commit/pre-commit-hooks + - id: no-debug-statement + - id: deps-in-desc + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: - - id: check-added-large-files - args: ['--maxkb=200'] - - id: file-contents-sorter + - id: check-added-large-files + args: ["--maxkb=200"] + exclude: "tests/testthat/data/sample_fit.RDS" + - id: file-contents-sorter files: '^\.Rbuildignore$' - - id: end-of-file-fixer + - id: end-of-file-fixer exclude: '(\.Rd)|(tests/testthat/_snaps/)' - - id: check-yaml - - id: check-toml - - id: mixed-line-ending - args: ['--fix=lf'] - - id: trailing-whitespace - exclude: 'tests/testthat/_snaps/' -- repo: https://github.com/pre-commit-ci/pre-commit-ci-config + - id: check-yaml + - id: check-toml + - id: mixed-line-ending + args: ["--fix=lf"] + - id: trailing-whitespace + exclude: "tests/testthat/_snaps/" + - repo: https://github.com/pre-commit-ci/pre-commit-ci-config rev: v1.6.1 
hooks: - # Only required when https://pre-commit.ci is used for config validation - - id: check-pre-commit-ci-config -- repo: local + # Only required when https://pre-commit.ci is used for config validation + - id: check-pre-commit-ci-config + - repo: local hooks: - - id: forbid-to-commit + - id: forbid-to-commit name: Don't commit common R artifacts entry: Cannot commit .Rhistory, .RData, .Rds or .rds. language: fail files: '\.(Rhistory|RData|Rds|rds)$' # `exclude: ` to allow committing specific files -##### -# Python -- repo: https://github.com/psf/black + ##### + # Python + - repo: https://github.com/psf/black rev: 24.8.0 hooks: - # if you have ipython notebooks, consider using - # `black-jupyter` hook instead - - id: black - args: ['--line-length', '79'] -- repo: https://github.com/PyCQA/isort + # if you have ipython notebooks, consider using + # `black-jupyter` hook instead + - id: black + args: ["--line-length", "79"] + - repo: https://github.com/PyCQA/isort rev: 5.13.2 hooks: - - id: isort - args: ['--profile', 'black', - '--line-length', '79'] -- repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.4 - hooks: - - id: ruff -##### -# Java -- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.14.0 - hooks: - - id: pretty-format-java - args: [--aosp,--autofix] -##### -# Julia -# Due to lack of first-class Julia support, this needs Julia local install -# and JuliaFormatter.jl installed in the library -# - repo: https://github.com/domluna/JuliaFormatter.jl -# rev: v1.0.39 -# hooks: -# - id: julia-formatter -##### -# Secrets -- repo: https://github.com/Yelp/detect-secrets + - id: isort + args: ["--profile", "black", "--line-length", "79"] + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.4 + hooks: + - id: ruff + ##### + # Java + - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks + rev: v2.14.0 + hooks: + - id: pretty-format-java + args: [--aosp, --autofix] + ##### + # Julia + # 
Due to lack of first-class Julia support, this needs Julia local install + # and JuliaFormatter.jl installed in the library + # - repo: https://github.com/domluna/JuliaFormatter.jl + # rev: v1.0.39 + # hooks: + # - id: julia-formatter + ##### + # Secrets + - repo: https://github.com/Yelp/detect-secrets rev: v1.5.0 hooks: - - id: detect-secrets - args: ['--baseline', '.secrets.baseline'] + - id: detect-secrets + args: ["--baseline", ".secrets.baseline"] exclude: package.lock.json ci: - autofix_commit_msg: | - [pre-commit.ci] auto fixes from pre-commit.com hooks + autofix_commit_msg: | + [pre-commit.ci] auto fixes from pre-commit.com hooks - for more information, see https://pre-commit.ci - autofix_prs: true - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' - autoupdate_schedule: weekly - submodules: false + for more information, see https://pre-commit.ci + autofix_prs: true + autoupdate_branch: "" + autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" + autoupdate_schedule: weekly + submodules: false diff --git a/DESCRIPTION b/DESCRIPTION index f66b34c3..1d1517b2 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,13 +32,6 @@ Imports: EpiNow2 (>= 1.4.0), rlang, rstan -LinkingTo: - BH (>= 1.66.0), - Rcpp (>= 0.12.0), - RcppEigen (>= 0.3.3.3.0), - RcppParallel (>= 5.0.1), - rstan (>= 2.26.0), - StanHeaders (>= 2.26.0) Additional_repositories: https://mc-stan.org/r-packages/ URL: https://cdcgov.github.io/cfa-epinow2-pipeline/ diff --git a/NAMESPACE b/NAMESPACE index 7016e7c9..50a5e104 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,12 +2,14 @@ export(apply_exclusions) export(download_from_azure_blob) +export(extract_diagnostics) export(fetch_blob_container) export(fetch_credential_from_env_var) export(fit_model) export(format_delay_interval) export(format_generation_interval) export(format_right_truncation) +export(low_case_count_diagnostic) export(read_data) export(read_disease_parameters) export(read_exclusions) diff --git 
a/R/extract_diagnostics.R b/R/extract_diagnostics.R index eb123d0a..c52cb3ac 100644 --- a/R/extract_diagnostics.R +++ b/R/extract_diagnostics.R @@ -1,7 +1,54 @@ +#' Extract diagnostic metrics from model fit and data +#' +#' This function extracts various diagnostic metrics from a fitted `epinow2` +#' model and provided data. It checks for low case counts and computes +#' diagnostics from the fitted model, including the mean acceptance +#' statistic, divergent transitions, maximum tree depth, and Rhat values. +#' These diagnostics are then flagged if they exceed specific thresholds, +#' and the results are returned as a data frame. +#' +#' @param fit A list containing the model fit object from `epinow2`, which +#' includes `estimates$fit`. +#' @param data A data frame containing the input data used in the model fit. +#' @param job_id A unique identifier for the job or task being processed. +#' @param task_id A unique identifier for the task being performed. +#' +#' +#' @return A \code{data.frame} containing the extracted diagnostic metrics. The +#' data frame includes the following columns: +#' \itemize{ +#' \item \code{diagnostic}: The name of the diagnostic metric. +#' \item \code{value}: The value of the diagnostic metric. +#' \item \code{state}: The state for which the model was run. +#' \item \code{disease}: The disease/pathogen being analyzed. +#' \item \code{job_id}: The unique identifier for the job. +#' \item \code{task_id}: The unique identifier for the task. +#' } +#' +#' @details +#' The following diagnostics are calculated: +#' \itemize{ +#' \item \code{mean_accept_stat}: The average acceptance statistic across +#' all chains. +#' \item \code{p_divergent}: The proportion of divergent transitions across +#' all samples. +#' \item \code{p_max_treedepth}: The proportion of samples that hit the +#' maximum tree depth. +#' \item \code{p_high_rhat}: The proportion of parameters with Rhat values +#' greater than 1.05, indicating potential convergence issues. 
+#' \item \code{low_case_count_flag}: A flag indicating if there are low case +#' counts in the data. See \code{low_case_count_diagnostic()} for more +#' information on this diagnostic. +#' \item \code{diagnostic_flag}: A combined flag that indicates if +#' any diagnostic thresholds are exceeded. +#' } +#' @export extract_diagnostics <- function(fit, data, job_id, task_id) { low_case_count <- low_case_count_diagnostic(data) - epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit) + epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit, + inc_warmup = FALSE + ) mean_accept_stat <- mean( sapply(epinow2_diagnostics, function(x) mean(x[, "accept_stat__"])) ) @@ -17,37 +64,40 @@ extract_diagnostics <- function(fit, data, job_id, task_id) { rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05, na.rm = TRUE ) - any_diagnostic_flag <- any( + + # Combine all diagnostic flags into one flag + diagnostic_flag <- any( mean_accept_stat < 0.1, p_divergent > 0.0075, p_max_treedepth > 0.05, p_high_rhat > 0.0075 ) - diagnostic_df <- data.frame( - diagnostic = c( - "mean_accept_stat", - "p_divergent", - "p_max_treedepth", - "p_high_rhat", - "epinow2_diagnostic_flag", - "low_case_count_flag" - ), - value = c( - mean_accept_stat, - p_divergent, - p_max_treedepth, - p_high_rhat, - epinow2_diagnostic_flag, - low_case_count - ), - "state" = state, - disease = pathogen, + # Create individual vectors for the columns of the diagnostics data frame + diagnostic_names <- c( + "mean_accept_stat", + "p_divergent", + "p_max_treedepth", + "p_high_rhat", + "diagnostic_flag", + "low_case_count_flag" + ) + diagnostic_values <- c( + mean_accept_stat, + p_divergent, + p_max_treedepth, + p_high_rhat, + diagnostic_flag, + low_case_count ) - return(bind_rows(epinow2_df, data_df)) + data.frame( + diagnostic = diagnostic_names, + value = diagnostic_values, + job_id = job_id, + task_id = task_id + ) } - #' Calculate low case count diagnostic flag #' #' The diagnostic 
flag is TRUE if either of the _last_ two weeks of the dataset @@ -57,9 +107,12 @@ extract_diagnostics <- function(fit, data, job_id, task_id) { #' This function assumes that the `epinow2_df` input dataset has been #' "completed": that any implicit missingness has been made explicit. #' -#' @param df A dataframe as returned by [read_data()] +#' @param df A dataframe as returned by [read_data()]. The dataframe must +#' include columns such as `reference_date` (a date vector) and `confirm` +#' (the number of confirmed cases per day). #' -#' @return +#' @return A logical value (TRUE or FALSE) indicating whether either of the last +#' two weeks in the dataset had fewer than 10 cases per week. #' @export low_case_count_diagnostic <- function(df) { # Get the dates in the last and second-to-last weeks diff --git a/man/extract_diagnostics.Rd b/man/extract_diagnostics.Rd new file mode 100644 index 00000000..5356e71b --- /dev/null +++ b/man/extract_diagnostics.Rd @@ -0,0 +1,56 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_diagnostics.R +\name{extract_diagnostics} +\alias{extract_diagnostics} +\title{Extract diagnostic metrics from model fit and data} +\usage{ +extract_diagnostics(fit, data, job_id, task_id) +} +\arguments{ +\item{fit}{A list containing the model fit object from \code{epinow2}, which +includes \code{estimates$fit}.} + +\item{data}{A data frame containing the input data used in the model fit.} + +\item{job_id}{A unique identifier for the job or task being processed.} + +\item{task_id}{A unique identifier for the task being performed.} +} +\value{ +A \code{data.frame} containing the extracted diagnostic metrics. The +data frame includes the following columns: +\itemize{ +\item \code{diagnostic}: The name of the diagnostic metric. +\item \code{value}: The value of the diagnostic metric. +\item \code{state}: The state for which the model was run. +\item \code{disease}: The disease/pathogen being analyzed. 
+\item \code{job_id}: The unique identifier for the job. +\item \code{task_id}: The unique identifier for the task. +} +} +\description{ +This function extracts various diagnostic metrics from a fitted \code{epinow2} +model and provided data. It checks for low case counts and computes +diagnostics from the fitted model, including the mean acceptance +statistic, divergent transitions, maximum tree depth, and Rhat values. +These diagnostics are then flagged if they exceed specific thresholds, +and the results are returned as a data frame. +} +\details{ +The following diagnostics are calculated: +\itemize{ +\item \code{mean_accept_stat}: The average acceptance statistic across +all chains. +\item \code{p_divergent}: The proportion of divergent transitions across +all samples. +\item \code{p_max_treedepth}: The proportion of samples that hit the +maximum tree depth. +\item \code{p_high_rhat}: The proportion of parameters with Rhat values +greater than 1.05, indicating potential convergence issues. +\item \code{low_case_count_flag}: A flag indicating if there are low case +counts in the data. See \code{low_case_count_diagnostic()} for more +information on this diagnostic. +\item \code{diagnostic_flag}: A combined flag that indicates if +any diagnostic thresholds are exceeded. +} +} diff --git a/man/low_case_count_diagnostic.Rd b/man/low_case_count_diagnostic.Rd new file mode 100644 index 00000000..501f22ea --- /dev/null +++ b/man/low_case_count_diagnostic.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/extract_diagnostics.R +\name{low_case_count_diagnostic} +\alias{low_case_count_diagnostic} +\title{Calculate low case count diagnostic flag} +\usage{ +low_case_count_diagnostic(df) +} +\arguments{ +\item{df}{A dataframe as returned by \code{\link[=read_data]{read_data()}}. 
The dataframe must +include columns such as \code{reference_date} (a date vector) and \code{confirm} +(the number of confirmed cases per day).} +} +\value{ +A logical value (TRUE or FALSE) indicating whether either of the last +two weeks in the dataset had fewer than 10 cases per week. +} +\description{ +The diagnostic flag is TRUE if either of the \emph{last} two weeks of the dataset +have fewer than an aggregate 10 cases per week. This aggregation excludes the +count from confirmed outliers, which have been set to NA in the data. +} +\details{ +This function assumes that the \code{epinow2_df} input dataset has been +"completed": that any implicit missingness has been made explicit. +} diff --git a/tests/testthat/data/sample_fit.RDS b/tests/testthat/data/sample_fit.RDS new file mode 100644 index 00000000..7a137f7e Binary files /dev/null and b/tests/testthat/data/sample_fit.RDS differ diff --git a/tests/testthat/test-extract_diagnostics.R b/tests/testthat/test-extract_diagnostics.R new file mode 100644 index 00000000..14baa66b --- /dev/null +++ b/tests/testthat/test-extract_diagnostics.R @@ -0,0 +1,143 @@ +test_that("Fitted model extracts diagnostics", { + # Arrange + data_path <- test_path("data/test_data.parquet") + con <- DBI::dbConnect(duckdb::duckdb()) + expected <- DBI::dbGetQuery(con, " + SELECT + report_date, + reference_date, + disease, + geo_value AS state_abb, + value AS confirm + FROM read_parquet(?) 
+ WHERE reference_date <= '2023-01-22'", + params = list(data_path) + ) + DBI::dbDisconnect(con) + fit_path <- test_path("data", "sample_fit.RDS") + fit <- readRDS(fit_path) + + # Expected diagnostics + expected <- data.frame( + diagnostic = c( + "mean_accept_stat", + "p_divergent", + "p_max_treedepth", + "p_high_rhat", + "diagnostic_flag", + "low_case_count_flag" + ), + value = c( + 0.94240233, + 0.00000000, + 0.00000000, + 0.00000000, + 0.00000000, + 1.00000000 + ), + job_id = rep("test", 6), + task_id = rep("test", 6), + stringsAsFactors = FALSE + ) + actual <- extract_diagnostics(fit, data, "test", "test") + + testthat::expect_equal( + actual, + expected + ) +}) + +test_that("Cases below threshold returns TRUE", { + # Arrange + true_df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + confirm = c(9, rep(0, 12), 9) + ) + + # Act + diagnostic <- low_case_count_diagnostic(true_df) + + # Assert + expect_true(diagnostic) +}) + +test_that("Cases above threshold returns FALSE", { + # Arrange + false_df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + confirm = rep(10, 14) + ) + + # Act + diagnostic <- low_case_count_diagnostic(false_df) + + # Assert + expect_false(diagnostic) +}) + + +test_that("Only the last two weeks are evaluated", { + # Arrange + # 3 weeks, first week would pass but last week does not + df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 21 + ), + # Week 1: 700, Week 2: 700, Week 3: 0 + confirm = c(rep(100, 14), rep(0, 7)) + ) + + # Act + diagnostic <- low_case_count_diagnostic(df) + + # Assert + expect_true(diagnostic) +}) + +test_that("Old approach's negative is now positive", { + # Arrange + df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + # Week 1: 21, Week 2: 0 + confirm = c(rep(3, 7), rep(0, 
7)) + ) + + # Act + diagnostic <- low_case_count_diagnostic(df) + + # Assert + expect_true(diagnostic) +}) + +test_that("NAs are evaluated as 0", { + # Arrange + df <- data.frame( + reference_date = seq.Date( + from = as.Date("2023-01-01"), + by = "day", + length.out = 14 + ), + # Week 1: 6 (not NA!), Week 2: 700 + confirm = c(NA_real_, rep(1, 6), rep(100, 7)) + ) + + # Act + diagnostic <- low_case_count_diagnostic(df) + + # Assert + expect_true(diagnostic) +})