Skip to content

Commit

Permalink
Set up diagnostics
Browse files Browse the repository at this point in the history
  • Loading branch information
zsusswein committed Sep 16, 2024
1 parent cf892f9 commit b329b8f
Show file tree
Hide file tree
Showing 8 changed files with 381 additions and 104 deletions.
148 changes: 76 additions & 72 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,99 +1,103 @@
# All available hooks: https://pre-commit.com/hooks.html
# R specific hooks: https://github.com/lorenzwalthert/precommit
repos:
# R
- repo: https://github.com/lorenzwalthert/precommit
# R
- repo: https://github.com/lorenzwalthert/precommit
rev: v0.4.3
hooks:
- id: style-files
args: [--style_pkg=styler, --style_fun=tidyverse_style,
--cache-root=styler-perm]
- id: use-tidy-description
- id: lintr
- id: readme-rmd-rendered
- id: parsable-R
- id: no-browser-statement
- id: no-print-statement
- id: style-files
args:
[
--style_pkg=styler,
--style_fun=tidyverse_style,
--cache-root=styler-perm,
]
- id: use-tidy-description
- id: lintr
- id: readme-rmd-rendered
- id: parsable-R
- id: no-browser-statement
- id: no-print-statement
exclude: '^tests/testthat/test-print\.R$'
- id: no-debug-statement
- id: deps-in-desc
- repo: https://github.com/pre-commit/pre-commit-hooks
- id: no-debug-statement
- id: deps-in-desc
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: check-added-large-files
args: ['--maxkb=200']
- id: file-contents-sorter
- id: check-added-large-files
args: ["--maxkb=200"]
exclude: "tests/testthat/data/sample_fit.RDS"
- id: file-contents-sorter
files: '^\.Rbuildignore$'
- id: end-of-file-fixer
- id: end-of-file-fixer
exclude: '(\.Rd)|(tests/testthat/_snaps/)'
- id: check-yaml
- id: check-toml
- id: mixed-line-ending
args: ['--fix=lf']
- id: trailing-whitespace
exclude: 'tests/testthat/_snaps/'
- repo: https://github.com/pre-commit-ci/pre-commit-ci-config
- id: check-yaml
- id: check-toml
- id: mixed-line-ending
args: ["--fix=lf"]
- id: trailing-whitespace
exclude: "tests/testthat/_snaps/"
- repo: https://github.com/pre-commit-ci/pre-commit-ci-config
rev: v1.6.1
hooks:
# Only required when https://pre-commit.ci is used for config validation
- id: check-pre-commit-ci-config
- repo: local
# Only required when https://pre-commit.ci is used for config validation
- id: check-pre-commit-ci-config
- repo: local
hooks:
- id: forbid-to-commit
- id: forbid-to-commit
name: Don't commit common R artifacts
entry: Cannot commit .Rhistory, .RData, .Rds or .rds.
language: fail
files: '\.(Rhistory|RData|Rds|rds)$'
# `exclude: <regex>` to allow committing specific files
#####
# Python
- repo: https://github.com/psf/black
#####
# Python
- repo: https://github.com/psf/black
rev: 24.8.0
hooks:
# if you have ipython notebooks, consider using
# `black-jupyter` hook instead
- id: black
args: ['--line-length', '79']
- repo: https://github.com/PyCQA/isort
# if you have ipython notebooks, consider using
# `black-jupyter` hook instead
- id: black
args: ["--line-length", "79"]
- repo: https://github.com/PyCQA/isort
rev: 5.13.2
hooks:
- id: isort
args: ['--profile', 'black',
'--line-length', '79']
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
hooks:
- id: ruff
#####
# Java
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.14.0
hooks:
- id: pretty-format-java
args: [--aosp,--autofix]
#####
# Julia
# Due to lack of first-class Julia support, this needs Julia local install
# and JuliaFormatter.jl installed in the library
# - repo: https://github.com/domluna/JuliaFormatter.jl
# rev: v1.0.39
# hooks:
# - id: julia-formatter
#####
# Secrets
- repo: https://github.com/Yelp/detect-secrets
- id: isort
args: ["--profile", "black", "--line-length", "79"]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.6.4
hooks:
- id: ruff
#####
# Java
- repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks
rev: v2.14.0
hooks:
- id: pretty-format-java
args: [--aosp, --autofix]
#####
# Julia
# Due to lack of first-class Julia support, this needs Julia local install
# and JuliaFormatter.jl installed in the library
# - repo: https://github.com/domluna/JuliaFormatter.jl
# rev: v1.0.39
# hooks:
# - id: julia-formatter
#####
# Secrets
- repo: https://github.com/Yelp/detect-secrets
rev: v1.5.0
hooks:
- id: detect-secrets
args: ['--baseline', '.secrets.baseline']
- id: detect-secrets
args: ["--baseline", ".secrets.baseline"]
exclude: package.lock.json
ci:
autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks
autofix_commit_msg: |
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_branch: ''
autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
autoupdate_schedule: weekly
submodules: false
for more information, see https://pre-commit.ci
autofix_prs: true
autoupdate_branch: ""
autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate"
autoupdate_schedule: weekly
submodules: false
7 changes: 0 additions & 7 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,6 @@ Imports:
EpiNow2 (>= 1.4.0),
rlang,
rstan
LinkingTo:
BH (>= 1.66.0),
Rcpp (>= 0.12.0),
RcppEigen (>= 0.3.3.3.0),
RcppParallel (>= 5.0.1),
rstan (>= 2.26.0),
StanHeaders (>= 2.26.0)
Additional_repositories:
https://mc-stan.org/r-packages/
URL: https://cdcgov.github.io/cfa-epinow2-pipeline/
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

export(apply_exclusions)
export(download_from_azure_blob)
export(extract_diagnostics)
export(fetch_blob_container)
export(fetch_credential_from_env_var)
export(fit_model)
export(format_delay_interval)
export(format_generation_interval)
export(format_right_truncation)
export(low_case_count_diagnostic)
export(read_data)
export(read_disease_parameters)
export(read_exclusions)
Expand Down
103 changes: 78 additions & 25 deletions R/extract_diagnostics.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,54 @@
#' Extract diagnostic metrics from model fit and data
#'
#' This function extracts various diagnostic metrics from a fitted `epinow2`
#' model and provided data. It checks for low case counts and computes
#' diagnostics from the fitted model, including the mean acceptance
#' statistic, divergent transitions, maximum tree depth, and Rhat values.
#' These diagnostics are then flagged if they exceed specific thresholds,
#' and the results are returned as a data frame.
#'
#' @param fit A list containing the model fit object from `epinow2`, which
#' includes `estimates$fit`.
#' @param data A data frame containing the input data used in the model fit.
#' @param job_id A unique identifier for the job or task being processed.
#' @param task_id A unique identifier for the task being performed.
#'
#'
#' @return A \code{data.frame} containing the extracted diagnostic metrics. The
#' data frame includes the following columns:
#' \itemize{
#' \item \code{diagnostic}: The name of the diagnostic metric.
#' \item \code{value}: The value of the diagnostic metric.
#' \item \code{state}: The state for which the model was run.
#' \item \code{disease}: The disease/pathogen being analyzed.
#' \item \code{job_id}: The unique identifier for the job.
#' \item \code{task_id}: The unique identifier for the task.
#' }
#'
#' @details
#' The following diagnostics are calculated:
#' \itemize{
#' \item \code{mean_accept_stat}: The average acceptance statistic across
#' all chains.
#' \item \code{p_divergent}: The proportion of divergent transitions across
#' all samples.
#' \item \code{p_max_treedepth}: The proportion of samples that hit the
#' maximum tree depth.
#' \item \code{p_high_rhat}: The proportion of parameters with Rhat values
#' greater than 1.05, indicating potential convergence issues.
#' \item \code{low_case_count_flag}: A flag indicating if there are low case
#' counts in the data. See \code{low_case_count_diagnostic()} for more
#' information on this diagnostic.
#' \item \code{epinow2_diagnostic_flag}: A combined flag that indicates if
#' any diagnostic thresholds are exceeded.
#' }
#' @export
extract_diagnostics <- function(fit, data, job_id, task_id) {
low_case_count <- low_case_count_diagnostic(data)

epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit)
epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit,
inc_warmup = FALSE
)
mean_accept_stat <- mean(
sapply(epinow2_diagnostics, function(x) mean(x[, "accept_stat__"]))
)
Expand All @@ -17,37 +64,40 @@ extract_diagnostics <- function(fit, data, job_id, task_id) {
rstan::summary(fit$estimates$fit)$summary[, "Rhat"] > 1.05,
na.rm = TRUE
)
any_diagnostic_flag <- any(

# Combine all diagnostic flags into one flag
diagnostic_flag <- any(
mean_accept_stat < 0.1,
p_divergent > 0.0075,
p_max_treedepth > 0.05,
p_high_rhat > 0.0075
)
diagnostic_df <- data.frame(
diagnostic = c(
"mean_accept_stat",
"p_divergent",
"p_max_treedepth",
"p_high_rhat",
"epinow2_diagnostic_flag",
"low_case_count_flag"
),
value = c(
mean_accept_stat,
p_divergent,
p_max_treedepth,
p_high_rhat,
epinow2_diagnostic_flag,
low_case_count
),
"state" = state,
disease = pathogen,
# Create individual vectors for the columns of the diagnostics data frame
diagnostic_names <- c(
"mean_accept_stat",
"p_divergent",
"p_max_treedepth",
"p_high_rhat",
"diagnostic_flag",
"low_case_count_flag"
)
diagnostic_values <- c(
mean_accept_stat,
p_divergent,
p_max_treedepth,
p_high_rhat,
diagnostic_flag,
low_case_count
)

return(bind_rows(epinow2_df, data_df))
data.frame(
diagnostic = diagnostic_names,
value = diagnostic_values,
job_id = job_id,
task_id = task_id
)
}


#' Calculate low case count diagnostic flag
#'
#' The diagnostic flag is TRUE if either of the _last_ two weeks of the dataset
Expand All @@ -57,9 +107,12 @@ extract_diagnostics <- function(fit, data, job_id, task_id) {
#' This function assumes that the `epinow2_df` input dataset has been
#' "completed": that any implicit missingness has been made explicit.
#'
#' @param df A dataframe as returned by [read_data()]
#' @param df A dataframe as returned by [read_data()]. The dataframe must
#' include columns such as `reference_date` (a date vector) and `confirm`
#' (the number of confirmed cases per day).
#'
#' @return
#' @return A logical value (TRUE or FALSE) indicating whether either of the last
#' two weeks in the dataset had fewer than 10 cases per week.
#' @export
low_case_count_diagnostic <- function(df) {
# Get the dates in the last and second-to-last weeks
Expand Down
56 changes: 56 additions & 0 deletions man/extract_diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit b329b8f

Please sign in to comment.