From cdc6e249dfd1602985e4ab68f4b75577c6feddfd Mon Sep 17 00:00:00 2001 From: Anthony Sena Date: Wed, 5 Jun 2024 18:28:21 -0400 Subject: [PATCH] Prevent upload of inclusion rule names (#156) * Isolate cohort inclusion rule function * Export inclusion rule stats without the need to insert to db --- NAMESPACE | 1 + R/CohortStats.R | 114 +++++++++++-------- R/Export.R | 89 ++++++++------- man/exportCohortStatsTables.Rd | 20 +++- man/getCohortInclusionRules.Rd | 21 ++++ tests/testthat/test-Export.R | 200 +++++++++++++++++++++------------ 6 files changed, 284 insertions(+), 161 deletions(-) create mode 100644 man/getCohortInclusionRules.Rd diff --git a/NAMESPACE b/NAMESPACE index becbe95..e5a3448 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -23,6 +23,7 @@ export(generateCohortSet) export(generateNegativeControlOutcomeCohorts) export(getCohortCounts) export(getCohortDefinitionSet) +export(getCohortInclusionRules) export(getCohortStats) export(getCohortTableNames) export(getRequiredTasks) diff --git a/R/CohortStats.R b/R/CohortStats.R index 9b1c2b5..e3ce475 100644 --- a/R/CohortStats.R +++ b/R/CohortStats.R @@ -47,14 +47,6 @@ insertInclusionRuleNames <- function(connectionDetails = NULL, stop("You must provide either a database connection or the connection details.") } - checkmate::assertDataFrame(cohortDefinitionSet, min.rows = 1, col.names = "named") - checkmate::assertNames(colnames(cohortDefinitionSet), - must.include = c( - "cohortId", - "cohortName", - "json" - ) - ) if (is.null(connection)) { connection <- DatabaseConnector::connect(connectionDetails) on.exit(DatabaseConnector::disconnect(connection)) @@ -65,45 +57,8 @@ insertInclusionRuleNames <- function(connectionDetails = NULL, stop(paste0(cohortInclusionTable, " table not found in schema: ", cohortDatabaseSchema, ". Please make sure the table is created using the createCohortTables() function before calling this function.")) } - # Assemble the cohort inclusion rules - # NOTE: This data frame must match the @cohort_inclusion_table - # structure as defined in inst/sql/sql_server/CreateCohortTables.sql - inclusionRules <- data.frame( - cohortDefinitionId = bit64::integer64(), - ruleSequence = integer(), - name = character(), - description = character() - ) - # Remove any cohort definitions that do not include the JSON property - cohortDefinitionSet <- cohortDefinitionSet[!(is.null(cohortDefinitionSet$json) | is.na(cohortDefinitionSet$json)), ] - for (i in 1:nrow(cohortDefinitionSet)) { - cohortDefinition <- RJSONIO::fromJSON(content = cohortDefinitionSet$json[i], digits = 23) - if (!is.null(cohortDefinition$InclusionRules)) { - nrOfRules <- length(cohortDefinition$InclusionRules) - if (nrOfRules > 0) { - for (j in 1:nrOfRules) { - ruleName <- cohortDefinition$InclusionRules[[j]]$name - ruleDescription <- cohortDefinition$InclusionRules[[j]]$description - if (is.na(ruleName) || ruleName == "") { - ruleName <- paste0("Unamed rule (Sequence ", j - 1, ")") - } - if (is.null(ruleDescription)) { - ruleDescription <- "" - } - inclusionRules <- rbind( - inclusionRules, - data.frame( - cohortDefinitionId = bit64::as.integer64(cohortDefinitionSet$cohortId[i]), - ruleSequence = as.integer(j - 1), - name = ruleName, - description = ruleDescription - ) - ) - } - } - } - } - + inclusionRules <- getCohortInclusionRules(cohortDefinitionSet) + # Remove any existing data to prevent duplication DatabaseConnector::renderTranslateExecuteSql( connection = connection, @@ -174,6 +129,7 @@ getStatsTable <- function(connectionDetails, } #' Get Cohort Inclusion Stats Table Data +#' #' @description #' This function returns a data frame of the data in the Cohort Inclusion Tables. #' Results are organized in to a list with 5 different data frames: @@ -244,3 +200,67 @@ getCohortStats <- function(connectionDetails, } return(results) } + + +#' Get Cohort Inclusion Rules from a cohort definition set +#' +#' @description +#' This function returns a data frame of the inclusion rules defined +#' in a cohort definition set. +#' +#' @md +#' @template CohortDefinitionSet +#' +#' @export +getCohortInclusionRules <- function(cohortDefinitionSet) { + checkmate::assertDataFrame(cohortDefinitionSet, min.rows = 1, col.names = "named") + checkmate::assertNames(colnames(cohortDefinitionSet), + must.include = c( + "cohortId", + "cohortName", + "json" + ) + ) + + # Assemble the cohort inclusion rules + # NOTE: This data frame must match the @cohort_inclusion_table + # structure as defined in inst/sql/sql_server/CreateCohortTables.sql + inclusionRules <- data.frame( + cohortDefinitionId = bit64::integer64(), + ruleSequence = integer(), + name = character(), + description = character() + ) + + # Remove any cohort definitions that do not include the JSON property + cohortDefinitionSet <- cohortDefinitionSet[!(is.null(cohortDefinitionSet$json) | is.na(cohortDefinitionSet$json)), ] + for (i in 1:nrow(cohortDefinitionSet)) { + cohortDefinition <- RJSONIO::fromJSON(content = cohortDefinitionSet$json[i], digits = 23) + if (!is.null(cohortDefinition$InclusionRules)) { + nrOfRules <- length(cohortDefinition$InclusionRules) + if (nrOfRules > 0) { + for (j in 1:nrOfRules) { + ruleName <- cohortDefinition$InclusionRules[[j]]$name + ruleDescription <- cohortDefinition$InclusionRules[[j]]$description + if (is.na(ruleName) || ruleName == "") { + ruleName <- paste0("Unamed rule (Sequence ", j - 1, ")") + } + if (is.null(ruleDescription)) { + ruleDescription <- "" + } + inclusionRules <- rbind( + inclusionRules, + data.frame( + cohortDefinitionId = bit64::as.integer64(cohortDefinitionSet$cohortId[i]), + ruleSequence = as.integer(j - 1), + name = ruleName, + description = ruleDescription + ) + ) + } + } + } + } + + invisible(inclusionRules) +} \ No newline at end of file diff --git a/R/Export.R b/R/Export.R index f351821..b0c013b 100644 --- a/R/Export.R +++ b/R/Export.R @@ -19,7 +19,14 @@ #' @description #' This function retrieves the data from the cohort statistics tables and #' writes them to the inclusion statistics folder specified in the function -#' call. +#' call. NOTE: inclusion rule names are handled in one of two ways: +#' +#' 1. You can specify the cohortDefinitionSet parameter and the inclusion rule +#' names will be extracted from the data.frame. +#' 2. You can insert the inclusion rule names into the database using the +#' insertInclusionRuleNames function of this package. +#' +#' The first approach is preferred as to avoid the warning emitted. #' #' @template Connection #' @@ -38,6 +45,8 @@ #' #' @param databaseId Optional - when specified, the databaseId will be added #' to the exported results +#' +#' @template CohortDefinitionSet #' #' @export exportCohortStatsTables <- function(connectionDetails, @@ -48,7 +57,8 @@ exportCohortStatsTables <- function(connectionDetails, snakeCaseToCamelCase = TRUE, fileNamesInSnakeCase = FALSE, incremental = FALSE, - databaseId = NULL) { + databaseId = NULL, + cohortDefinitionSet = NULL) { if (is.null(connection)) { # Establish the connection and ensure the cleanup is performed connection <- DatabaseConnector::connect(connectionDetails) @@ -58,20 +68,10 @@ exportCohortStatsTables <- function(connectionDetails, if (!dir.exists(cohortStatisticsFolder)) { dir.create(cohortStatisticsFolder, recursive = TRUE) } - - # Export the stats - exportStats <- function(table, - fileName, - includeDatabaseId) { - data <- getStatsTable( - connection = connection, - table = table, - snakeCaseToCamelCase = snakeCaseToCamelCase, - databaseId = databaseId, - cohortDatabaseSchema = cohortDatabaseSchema, - includeDatabaseId = includeDatabaseId - ) - + + # Internal function to export the stats + exportStats <- function(data, + fileName) { fullFileName <- file.path(cohortStatisticsFolder, fileName) ParallelLogger::logInfo("- Saving data to - ", fullFileName) if (incremental) { @@ -86,41 +86,44 @@ exportCohortStatsTables <- function(connectionDetails, .writeCsv(x = data, file = fullFileName) } } - + tablesToExport <- data.frame( - tableName = cohortTableNames$cohortInclusionTable, - fileName = "cohort_inclusion.csv", - includeDatabaseId = FALSE + tableName = c("cohortInclusionResultTable", "cohortInclusionStatsTable", "cohortSummaryStatsTable", "cohortCensorStatsTable"), + fileName = c("cohort_inc_result.csv", "cohort_inc_stats.csv", "cohort_summary_stats.csv", "cohort_censor_stats.csv") + ) + + if (is.null(cohortDefinitionSet)) { + warning("No cohortDefinitionSet specified; please make sure you've inserted the inclusion rule names using the insertInclusionRuleNames function.") + tablesToExport <- rbind(tablesToExport, data.frame( + tableName = "cohortInclusionTable", + fileName = "cohort_inclusion.csv" + )) + } else { + inclusionRules <- getCohortInclusionRules(cohortDefinitionSet) + exportStats( + data = inclusionRules, + fileName = "cohort_inclusion.csv" + ) + } + + # Get the cohort statistics + cohortStats <- getCohortStats( + connectionDetails = connectionDetails, + connection = connection, + cohortDatabaseSchema = cohortDatabaseSchema, + databaseId = databaseId, + snakeCaseToCamelCase = snakeCaseToCamelCase, + cohortTableName = cohortTableNames ) - tablesToExport <- rbind(tablesToExport, data.frame( - tableName = cohortTableNames$cohortInclusionResultTable, - fileName = "cohort_inc_result.csv", - includeDatabaseId = TRUE - )) - tablesToExport <- rbind(tablesToExport, data.frame( - tableName = cohortTableNames$cohortInclusionStatsTable, - fileName = "cohort_inc_stats.csv", - includeDatabaseId = TRUE - )) - tablesToExport <- rbind(tablesToExport, data.frame( - tableName = cohortTableNames$cohortSummaryStatsTable, - fileName = "cohort_summary_stats.csv", - includeDatabaseId = TRUE - )) - tablesToExport <- rbind(tablesToExport, data.frame( - tableName = cohortTableNames$cohortCensorStatsTable, - fileName = "cohort_censor_stats.csv", - includeDatabaseId = TRUE - )) + for (i in 1:nrow(tablesToExport)) { fileName <- ifelse(test = fileNamesInSnakeCase, yes = tablesToExport$fileName[i], no = SqlRender::snakeCaseToCamelCase(tablesToExport$fileName[i]) ) exportStats( - table = tablesToExport$tableName[i], - fileName = fileName, - includeDatabaseId = tablesToExport$includeDatabaseId[i] + data = cohortStats[[tablesToExport$tableName[i]]], + fileName = fileName ) } } diff --git a/man/exportCohortStatsTables.Rd b/man/exportCohortStatsTables.Rd index e4531fe..755dce5 100644 --- a/man/exportCohortStatsTables.Rd +++ b/man/exportCohortStatsTables.Rd @@ -13,7 +13,8 @@ exportCohortStatsTables( snakeCaseToCamelCase = TRUE, fileNamesInSnakeCase = FALSE, incremental = FALSE, - databaseId = NULL + databaseId = NULL, + cohortDefinitionSet = NULL ) } \arguments{ @@ -48,9 +49,24 @@ overwriting an existing results} \item{databaseId}{Optional - when specified, the databaseId will be added to the exported results} + +\item{cohortDefinitionSet}{The \code{cohortDefinitionSet} argument must be a data frame with +the following columns: \describe{ +\item{cohortId}{The unique integer identifier of the cohort} +\item{cohortName}{The cohort's name} +\item{sql}{The OHDSI-SQL used to generate the cohort}} +Optionally, this data frame may contain: \describe{ +\item{json}{The Circe JSON representation of the cohort}}} } \description{ This function retrieves the data from the cohort statistics tables and writes them to the inclusion statistics folder specified in the function -call. +call. NOTE: inclusion rule names are handled in one of two ways: + +1. You can specify the cohortDefinitionSet parameter and the inclusion rule +names will be extracted from the data.frame. +2. You can insert the inclusion rule names into the database using the +insertInclusionRuleNames function of this package. + +The first approach is preferred as to avoid the warning emitted. } diff --git a/man/getCohortInclusionRules.Rd b/man/getCohortInclusionRules.Rd new file mode 100644 index 0000000..cfb9589 --- /dev/null +++ b/man/getCohortInclusionRules.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/CohortStats.R +\name{getCohortInclusionRules} +\alias{getCohortInclusionRules} +\title{Get Cohort Inclusion Rules from a cohort definition set} +\usage{ +getCohortInclusionRules(cohortDefinitionSet) +} +\arguments{ +\item{cohortDefinitionSet}{The \code{cohortDefinitionSet} argument must be a data frame with +the following columns: \describe{ +\item{cohortId}{The unique integer identifier of the cohort} +\item{cohortName}{The cohort's name} +\item{sql}{The OHDSI-SQL used to generate the cohort}} +Optionally, this data frame may contain: \describe{ +\item{json}{The Circe JSON representation of the cohort}}} +} +\description{ +This function returns a data frame of the inclusion rules defined +in a cohort definition set. +} diff --git a/tests/testthat/test-Export.R b/tests/testthat/test-Export.R index bf1eaa4..59441d6 100644 --- a/tests/testthat/test-Export.R +++ b/tests/testthat/test-Export.R @@ -57,14 +57,17 @@ test_that("Export cohort stats with permanent tables", { ) checkmate::expect_names(names(cohortStats), subset.of = c("cohortInclusionStatsTable")) - # Export the results - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - incremental = FALSE - ) + + expect_warning( + # Export the results + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + incremental = FALSE + ) + ) # Verify the files are written to the file system exportedFiles <- list.files(path = cohortStatsFolder, pattern = "*.csv") @@ -102,13 +105,15 @@ test_that("Export cohort stats with databaseId", { ) # Export the results - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - incremental = FALSE, - databaseId = "Eunomia" + expect_warning( + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + incremental = FALSE, + databaseId = "Eunomia" + ) ) # Verify the files are written to the file system and have the database_id @@ -146,15 +151,17 @@ test_that("Export cohort stats with fileNamesInSnakeCase = TRUE", { incremental = FALSE ) - # Export the results - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - fileNamesInSnakeCase = TRUE, - incremental = FALSE, - databaseId = "Eunomia" + expect_warning( + # Export the results + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + fileNamesInSnakeCase = TRUE, + incremental = FALSE, + databaseId = "Eunomia" + ) ) # Verify the files are written to the file system and are in snake_case @@ -176,13 +183,15 @@ test_that("Export cohort stats in incremental mode", { cohortTableNames = cohortTableNames ) - # Export the results - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - incremental = TRUE + expect_warning( + # Export the results + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + incremental = TRUE + ) ) # Verify the files are written to the file system @@ -212,15 +221,17 @@ test_that("Export cohort stats with camelCase for column names", { cohortDefinitionSet = cohortsWithStats ) - # Export the results - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - snakeCaseToCamelCase = TRUE, - fileNamesInSnakeCase = TRUE, - incremental = TRUE + expect_warning( + # Export the results + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + snakeCaseToCamelCase = TRUE, + fileNamesInSnakeCase = TRUE, + incremental = TRUE + ) ) # Verify the files are written to the file system and the columns are in @@ -231,16 +242,18 @@ test_that("Export cohort stats with camelCase for column names", { expect_true(all(isCamelCase(names(data)))) } - # Export the results again in incremental mode and verify - # the results are preserved - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - snakeCaseToCamelCase = TRUE, - fileNamesInSnakeCase = TRUE, - incremental = TRUE + expect_warning( + # Export the results again in incremental mode and verify + # the results are preserved + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + snakeCaseToCamelCase = TRUE, + fileNamesInSnakeCase = TRUE, + incremental = TRUE + ) ) # Verify the cohort_inc_stats.csv contains cohortDefinitionIds c(2,3) @@ -273,15 +286,17 @@ test_that("Export cohort stats with snake_case for column names", { cohortDefinitionSet = cohortsWithStats ) - # Export the results - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - snakeCaseToCamelCase = FALSE, - fileNamesInSnakeCase = TRUE, - incremental = TRUE + expect_warning( + # Export the results + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + snakeCaseToCamelCase = FALSE, + fileNamesInSnakeCase = TRUE, + incremental = TRUE + ) ) # Verify the files are written to the file system and the columns are in @@ -292,16 +307,18 @@ test_that("Export cohort stats with snake_case for column names", { expect_true(all(isSnakeCase(names(data)))) } - # Export the results again in incremental mode and verify - # the results are preserved - exportCohortStatsTables( - connectionDetails = connectionDetails, - cohortDatabaseSchema = "main", - cohortTableNames = cohortTableNames, - cohortStatisticsFolder = cohortStatsFolder, - snakeCaseToCamelCase = FALSE, - fileNamesInSnakeCase = TRUE, - incremental = TRUE + expect_warning( + # Export the results again in incremental mode and verify + # the results are preserved + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + snakeCaseToCamelCase = FALSE, + fileNamesInSnakeCase = TRUE, + incremental = TRUE + ) ) # Verify the cohort_inc_stats.csv contains cohort_definition_id == c(2,3) @@ -312,3 +329,48 @@ test_that("Export cohort stats with snake_case for column names", { expect_equal(unique(data$cohort_definition_id), c(2, 3)) unlink(cohortStatsFolder) }) + +test_that("Export cohort stats using cohortDefinitionSet for inclusion rule names", { + cohortTableNames <- getCohortTableNames(cohortTable = "cohortStatsInclRule") + cohortStatsFolder <- file.path(outputFolder, "stats") + # First create the cohort tables + createCohortTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames + ) + + # Generate with stats + cohortsWithStats <- getCohortsForTest(cohorts, generateStats = TRUE) + generateCohortSet( + connectionDetails = connectionDetails, + cohortDefinitionSet = cohortsWithStats, + cdmDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortDatabaseSchema = "main", + incremental = FALSE + ) + + # Export the results + exportCohortStatsTables( + connectionDetails = connectionDetails, + cohortDatabaseSchema = "main", + cohortTableNames = cohortTableNames, + cohortStatisticsFolder = cohortStatsFolder, + incremental = FALSE, + databaseId = "Eunomia", + cohortDefinitionSet = cohortsWithStats + ) + + # Verify the files are written to the file system and that + # the cohort inclusion information has been written + exportedFiles <- list.files(path = cohortStatsFolder, pattern = ".csv", full.names = TRUE) + expect_true("cohortInclusion.csv" %in% basename(exportedFiles)) + for (i in 1:length(exportedFiles)) { + if (basename(exportedFiles[i]) == "cohortInclusion.csv") { + data <- CohortGenerator:::.readCsv(file = exportedFiles[i]) + expect_true(nrow(data) > 0) + } + } + unlink(cohortStatsFolder) +}) \ No newline at end of file