diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 307d747e..c8b1106e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,6 +44,7 @@ repos: entry: Cannot commit .Rhistory, .RData, .Rds or .rds. language: fail files: '\.(Rhistory|RData|Rds|rds)$' + exclude: '^tests/testthat/data/.*\.rds$' # `exclude: ` to allow committing specific files ##### # Python diff --git a/R/extract_diagnostics.R b/R/extract_diagnostics.R index 2ff52f55..c215ee25 100644 --- a/R/extract_diagnostics.R +++ b/R/extract_diagnostics.R @@ -12,7 +12,7 @@ #' @param data A data frame containing the input data used in the model fit. #' @param job_id A unique identifier for the job or task being processed. #' @param task_id A unique identifier for the task being performed. -#' +#' @param disease,geo_value,model Metadata for downstream processing. #' #' @return A \code{data.frame} containing the extracted diagnostic metrics. The #' data frame includes the following columns: @@ -23,6 +23,7 @@ #' \item \code{disease}: The disease/pathogen being analyzed. #' \item \code{job_id}: The unique identifier for the job. #' \item \code{task_id}: The unique identifier for the task. +#' \item \code{disease,geo_value,model}: Metadata for downstream processing. #' } #' #' @details @@ -43,7 +44,13 @@ #' any diagnostic thresholds are exceeded. #' } #' @export -extract_diagnostics <- function(fit, data, job_id, task_id) { +extract_diagnostics <- function(fit, + data, + job_id, + task_id, + disease, + geo_value, + model) { low_case_count <- low_case_count_diagnostic(data) epinow2_diagnostics <- rstan::get_sampler_params(fit$estimates$fit, @@ -94,7 +101,10 @@ extract_diagnostics <- function(fit, data, job_id, task_id) { diagnostic = diagnostic_names, value = diagnostic_values, job_id = job_id, - task_id = task_id + task_id = task_id, + disease = disease, + geo_value = geo_value, + model = model ) } diff --git a/R/write_output.R b/R/write_output.R index 40ff2f6a..6bd1b847 100644 --- a/R/write_output.R +++ b/R/write_output.R @@ -53,7 +53,7 @@ write_model_outputs <- function( job_id, "tasks", task_id, - "model.RDS" + "model.rds" ) saveRDS(fit, model_path) cli::cli_alert_success("Wrote model to {.path {model_path}}") @@ -193,7 +193,12 @@ extract_draws_from_fit <- function(fit) { #' @return A data.table with merged posterior draws and standardized parameter #' names. #' @noRd -post_process_and_merge <- function(draws, fact_table) { +post_process_and_merge <- function( + draws, + fact_table, + geo_value, + model, + disease) { # Step 1: Left join the date-time-parameter map onto the Stan draws merged_dt <- merge( draws, @@ -218,13 +223,18 @@ post_process_and_merge <- function(draws, fact_table) { ".point", ".interval", "date", ".iteration" ), new = c( - "_draw", "_chain", "_variable", "_value", "_lower", "_upper", "_width", + "_draw", "_chain", "_variable", "value", "_lower", "_upper", "_width", "_point", "_interval", "reference_date", "_iteration" ), # If using summaries, skip draws-specific names skip_absent = TRUE ) + # Metadata for downstream querying without path parsing or joins + data.table::set(merged_dt, j = "geo_value", value = factor(geo_value)) + data.table::set(merged_dt, j = "model", value = factor(model)) + data.table::set(merged_dt, j = "disease", value = factor(disease)) + return(merged_dt) } @@ -236,6 +246,7 @@ post_process_and_merge <- function(draws, fact_table) { #' returned in `{tidybayes}` format. #' #' @param fit An EpiNow2 fit object with posterior estimates. +#' @param disease,geo_value,model Metadata for downstream processing. #' #' @return A data.table of posterior draws or quantiles, merged and processed. #' @@ -244,17 +255,21 @@ NULL #' @rdname sample_processing_functions #' @export -process_samples <- function(fit) { +process_samples <- function(fit, geo_value, model, disease) { draws_list <- extract_draws_from_fit(fit) raw_processed_output <- post_process_and_merge( - draws_list$stan_draws, draws_list$fact_table + draws_list$stan_draws, + draws_list$fact_table, + geo_value, + model, + disease ) return(raw_processed_output) } #' @rdname sample_processing_functions #' @export -process_quantiles <- function(fit) { +process_quantiles <- function(fit, geo_value, model, disease) { # Step 1: Extract the draws draws_list <- extract_draws_from_fit(fit) @@ -268,7 +283,13 @@ process_quantiles <- function(fit) { data.table::as.data.table() # Step 3: Post-process summarized draws - post_process_and_merge(summarized_draws, draws_list$fact_table) + post_process_and_merge( + summarized_draws, + draws_list$fact_table, + geo_value, + model, + disease + ) } write_parquet <- function(data, path) { diff --git a/README.md b/README.md index f3ac7890..b9e13fa5 100644 --- a/README.md +++ b/README.md @@ -39,18 +39,20 @@ This package implements functions for: ## Output format +Outputs are stored in a s + ```bash output/ -├── job_/ +├── / │ ├── raw_samples/ │ │ ├── raw_samples_task_.parquet │ ├── summarized_quantiles/ │ │ ├── summarized_quantiles_task_.parquet │ ├── tasks/ -│ │ ├── task_/ +│ │ ├── / │ │ │ ├── model.stan │ │ │ ├── metadata.json -│ │ │ ├── task.log +│ │ │ ├── logs.txt │ │ │ └── error.log │ ├── job_metadata.json ``` diff --git a/man/extract_diagnostics.Rd b/man/extract_diagnostics.Rd index 5356e71b..22278feb 100644 --- a/man/extract_diagnostics.Rd +++ b/man/extract_diagnostics.Rd @@ -4,7 +4,7 @@ \alias{extract_diagnostics} \title{Extract diagnostic metrics from model fit and data} \usage{ -extract_diagnostics(fit, data, job_id, task_id) +extract_diagnostics(fit, data, job_id, task_id, disease, geo_value, model) } \arguments{ \item{fit}{A list containing the model fit object from \code{epinow2}, which @@ -15,6 +15,8 @@ includes \code{estimates$fit}.} \item{job_id}{A unique identifier for the job or task being processed.} \item{task_id}{A unique identifier for the task being performed.} + +\item{disease, geo_value, model}{Metadata for downstream processing.} } \value{ A \code{data.frame} containing the extracted diagnostic metrics. The @@ -26,6 +28,7 @@ data frame includes the following columns: \item \code{disease}: The disease/pathogen being analyzed. \item \code{job_id}: The unique identifier for the job. \item \code{task_id}: The unique identifier for the task. +\item \code{disease,geo_value,model}: Metadata for downstream processing. } } \description{ diff --git a/man/sample_processing_functions.Rd b/man/sample_processing_functions.Rd index e2019562..95635e92 100644 --- a/man/sample_processing_functions.Rd +++ b/man/sample_processing_functions.Rd @@ -6,12 +6,14 @@ \alias{process_quantiles} \title{Process posterior samples from a Stan fit object (raw draws).} \usage{ -process_samples(fit) +process_samples(fit, geo_value, model, disease) -process_quantiles(fit) +process_quantiles(fit, geo_value, model, disease) } \arguments{ \item{fit}{An EpiNow2 fit object with posterior estimates.} + +\item{disease, geo_value, model}{Metadata for downstream processing.} } \value{ A data.table of posterior draws or quantiles, merged and processed. diff --git a/tests/testthat/data/sample_fit.RDS b/tests/testthat/data/sample_fit.rds similarity index 100% rename from tests/testthat/data/sample_fit.RDS rename to tests/testthat/data/sample_fit.rds diff --git a/tests/testthat/helper-write_parameter_file.R b/tests/testthat/helper-write_parameter_file.R index 387ec9ca..0d7e60a5 100644 --- a/tests/testthat/helper-write_parameter_file.R +++ b/tests/testthat/helper-write_parameter_file.R @@ -6,6 +6,7 @@ write_sample_parameters_file <- function(value, parameter, start_date, end_date) { + Sys.sleep(0.05) df <- data.frame( start_date = as.Date(start_date), geo_value = state, @@ -16,6 +17,7 @@ write_sample_parameters_file <- function(value, ) con <- DBI::dbConnect(duckdb::duckdb()) + on.exit(DBI::dbDisconnect(con)) duckdb::duckdb_register(con, "test_table", df) # This is bad practice but `dbBind()` doesn't allow us to parameterize COPY @@ -24,7 +26,6 @@ write_sample_parameters_file <- function(value, # guard against a SQL injection attack. query <- paste0("COPY (SELECT * FROM test_table) TO '", path, "'") DBI::dbExecute(con, query) - DBI::dbDisconnect(con) invisible(path) } diff --git a/tests/testthat/test-extract_diagnostics.R b/tests/testthat/test-extract_diagnostics.R index 9b728d85..8b032001 100644 --- a/tests/testthat/test-extract_diagnostics.R +++ b/tests/testthat/test-extract_diagnostics.R @@ -14,7 +14,7 @@ test_that("Fitted model extracts diagnostics", { params = list(data_path) ) DBI::dbDisconnect(con) - fit_path <- test_path("data", "sample_fit.RDS") + fit_path <- test_path("data", "sample_fit.rds") fit <- readRDS(fit_path) # Expected diagnostics @@ -37,9 +37,20 @@ test_that("Fitted model extracts diagnostics", { ), job_id = rep("test", 6), task_id = rep("test", 6), + disease = rep("test", 6), + geo_value = rep("test", 6), + model = rep("test", 6), stringsAsFactors = FALSE ) - actual <- extract_diagnostics(fit, data, "test", "test") + actual <- extract_diagnostics( + fit, + data, + "test", + "test", + "test", + "test", + "test" + ) testthat::expect_equal( actual, diff --git a/tests/testthat/test-write_output.R b/tests/testthat/test-write_output.R index 106cb41b..a9ca1e7c 100644 --- a/tests/testthat/test-write_output.R +++ b/tests/testthat/test-write_output.R @@ -45,12 +45,12 @@ test_that("write_model_outputs writes files and directories correctly", { ) expect_true(file.exists(summarized_file)) - # Check if model RDS file was written + # Check if model rds file was written model_file <- file.path( job_id, "tasks", task_id, - "model.RDS" + "model.rds" ) expect_true(file.exists(model_file)) @@ -122,10 +122,10 @@ test_that("write_output_dir_structure generates dirs", { test_that("process_quantiles works as expected", { # Load the sample fit object - fit <- readRDS(test_path("data", "sample_fit.RDS")) + fit <- readRDS(test_path("data", "sample_fit.rds")) # Run the function on the fit object - result <- process_quantiles(fit) + result <- process_quantiles(fit, "test_geo", "test_model", "test_disease") # Test 1: Check if the result is a data.table expect_true( @@ -135,9 +135,18 @@ test_that("process_quantiles works as expected", { # Test 2: Check if the necessary columns exist in the result expected_columns <- c( - "time", "_variable", "_value", - "_lower", "_upper", "_width", - "_point", "_interval", "reference_date" + "time", + "_variable", + "value", + "_lower", + "_upper", + "_width", + "_point", + "_interval", + "reference_date", + "geo_value", + "model", + "disease" ) expect_equal( colnames(result), expected_columns @@ -187,10 +196,10 @@ test_that("process_quantiles works as expected", { test_that("process_samples works as expected", { # Load the sample fit object - fit <- readRDS(test_path("data", "sample_fit.RDS")) + fit <- readRDS(test_path("data", "sample_fit.rds")) # Run the function on the fit object - result <- process_samples(fit) + result <- process_samples(fit, "test_geo", "test_model", "test_disease") # Test 1: Check if the result is a data.table expect_true( @@ -200,9 +209,16 @@ test_that("process_samples works as expected", { # Test 2: Check if the necessary columns exist in the result expected_columns <- c( - "time", "_variable", "_chain", - "_iteration", "_draw", "_value", - "reference_date" + "time", + "_variable", + "_chain", + "_iteration", + "_draw", + "value", + "reference_date", + "geo_value", + "model", + "disease" ) expect_equal( colnames(result), expected_columns