From 8ee5515eda952c3cd0df645230091c6d097ab0be Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Mon, 18 Dec 2023 13:33:14 -0500 Subject: [PATCH 01/13] PSM CLI, table persistence --- cumulus_library/base_table_builder.py | 19 +- cumulus_library/cli.py | 152 ++++++++++------ cumulus_library/cli_parser.py | 28 ++- cumulus_library/databases.py | 8 +- cumulus_library/enums.py | 17 ++ cumulus_library/protected_table_builder.py | 58 ++++++ cumulus_library/statistics/psm.py | 148 +++++++++++++--- cumulus_library/study_parser.py | 167 +++++++++++++++--- .../template_sql/ctas_empty.sql.jinja | 7 +- .../psm_create_covariate_table.sql.jinja | 2 +- .../template_sql/statistics/psm_templates.py | 2 + cumulus_library/template_sql/templates.py | 10 +- docs/statistics/propensity-score-matching.md | 4 + tests/test_data/duckdb_data/duck.db | Bin 12288 -> 0 bytes tests/test_data/psm/manifest.toml | 7 + tests/test_data/psm/psm_cohort.sql | 3 + tests/test_data/psm/psm_config.toml | 7 +- .../test_data/psm/psm_config_no_optional.toml | 6 +- tests/test_psm.py | 2 +- tests/test_templates.py | 43 +++++ 20 files changed, 574 insertions(+), 116 deletions(-) create mode 100644 cumulus_library/enums.py create mode 100644 cumulus_library/protected_table_builder.py delete mode 100644 tests/test_data/duckdb_data/duck.db create mode 100644 tests/test_data/psm/manifest.toml create mode 100644 tests/test_data/psm/psm_cohort.sql diff --git a/cumulus_library/base_table_builder.py b/cumulus_library/base_table_builder.py index 512ae4d2..dfcd41fe 100644 --- a/cumulus_library/base_table_builder.py +++ b/cumulus_library/base_table_builder.py @@ -1,5 +1,6 @@ """ abstract base for python-based study executors """ import re +import sys from abc import ABC, abstractmethod from typing import final @@ -21,7 +22,7 @@ def __init__(self): self.queries = [] @abstractmethod - def prepare_queries(self, cursor: object, schema: str): + def prepare_queries(self, cursor: object, schema: str, *args, **kwargs): """Main entrypoint for python table builders. 
When completed, prepare_queries should populate self.queries with sql @@ -29,7 +30,8 @@ def prepare_queries(self, cursor: object, schema: str): :param cursor: A PEP-249 compatible cursor :param schema: A schema name - :param verbose: toggle for verbose output mode + :param db_type: The db system being used (only relevant for db-specific + query construction) """ raise NotImplementedError @@ -40,6 +42,8 @@ def execute_queries( schema: str, verbose: bool, drop_table: bool = False, + *args, + **kwargs, ): """Executes queries set up by a prepare_queries call @@ -48,7 +52,7 @@ def execute_queries( :param verbose: toggle for verbose output mode :param drop_table: drops any tables found in prepared_queries results """ - self.prepare_queries(cursor, schema) + self.prepare_queries(cursor, schema, *args, **kwargs) if drop_table: table_names = [] for query in self.queries: @@ -73,8 +77,11 @@ def execute_queries( ) for query in self.queries: query_console_output(verbose, query, progress, task) - cursor.execute(query) - self.post_execution(cursor, schema, verbose, drop_table) + try: + cursor.execute(query) + except Exception as e: + sys.exit(e) + self.post_execution(cursor, schema, verbose, drop_table, *args, **kwargs) def post_execution( self, @@ -82,6 +89,8 @@ def post_execution( schema: str, verbose: bool, drop_table: bool = False, + *args, + **kwargs, ): """Hook for any additional actions to run after execute_queries""" pass diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index 523f8e1d..b8d04de0 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -6,6 +6,7 @@ import sys import sysconfig +from datetime import datetime from pathlib import Path, PosixPath from typing import Dict, List, Optional @@ -19,57 +20,50 @@ DatabaseBackend, create_db_backend, ) +from cumulus_library.enums import PROTECTED_TABLES +from cumulus_library.protected_table_builder import TRANSACTIONS_COLS from cumulus_library.study_parser import StudyManifestParser +from cumulus_library.template_sql.templates import get_insert_into_query from cumulus_library.upload import upload_files -# ** Don't delete! ** -# This class isn't used in the rest of the code, -# but it is used manually as a quick & dirty alternative to the CLI. -class CumulusEnv: # pylint: disable=too-few-public-methods - """ - Wrapper for Cumulus Environment vars. - Simplifies connections to StudyBuilder without requiring CLI parsing. 
- """ - - def __init__(self): - self.region = os.environ.get("CUMULUS_LIBRARY_REGION", "us-east-1") - self.workgroup = os.environ.get("CUMULUS_LIBRARY_WORKGROUP", "cumulus") - self.profile = os.environ.get("CUMULUS_LIBRARY_PROFILE") - self.schema_name = os.environ.get("CUMULUS_LIBRARY_DATABASE") - - def get_study_builder(self): - """Convenience method for getting athena args from environment""" - db = AthenaDatabaseBackend( - self.region, self.workgroup, self.profile, self.schema_name - ) - return StudyBuilder(db) - - class StudyBuilder: """Class for managing Athena cursors and executing Cumulus queries""" verbose = False schema_name = None - def __init__(self, db: DatabaseBackend): + def __init__(self, db: DatabaseBackend, data_path: str): self.db = db + self.data_path = data_path self.cursor = db.cursor() self.schema_name = db.schema_name - def reset_data_path(self, study: PosixPath) -> None: - """ - Removes existing exports from a study's local data dir - """ - project_path = Path(__file__).resolve().parents[1] - path = Path(f"{str(project_path)}/data_export/{study}/") - if path.exists(): - for file in path.glob("*"): - file.unlink() + def update_transactions(self, prefix: str, status: str): + self.cursor.execute( + get_insert_into_query( + f"{prefix}__{PROTECTED_TABLES.TRANSACTIONS.value}", + TRANSACTIONS_COLS, + [ + [ + prefix, + __version__, + status, + datetime.now().replace(microsecond=0).isoformat(), + ] + ], + ) + ) ### Creating studies - def clean_study(self, targets: List[str], study_dict, prefix=False) -> None: + def clean_study( + self, + targets: List[str], + study_dict: Dict, + stats_clean: bool, + prefix: bool = False, + ) -> None: """Removes study table/views from Athena. While this is usually not required, since it it done as part of a build, @@ -86,25 +80,69 @@ def clean_study(self, targets: List[str], study_dict, prefix=False) -> None: if prefix: parser = StudyManifestParser() parser.clean_study( - self.cursor, self.schema_name, self.verbose, prefix=target + self.cursor, + self.schema_name, + verbose=self.verbose, + stats_clean=stats_clean, + prefix=target, ) else: parser = StudyManifestParser(study_dict[target]) - parser.clean_study(self.cursor, self.schema_name, self.verbose) + parser.clean_study( + self.cursor, + self.schema_name, + verbose=self.verbose, + stats_clean=stats_clean, + ) def clean_and_build_study( - self, target: PosixPath, continue_from: str = None + self, + target: PosixPath, + export_dir: PosixPath, + stats_build: bool, + continue_from: str = None, ) -> None: """Recreates study views/tables :param target: A PosixPath to the study directory """ - studyparser = StudyManifestParser(target) - if not continue_from: - studyparser.clean_study(self.cursor, self.schema_name, self.verbose) - studyparser.run_table_builder(self.cursor, self.schema_name, self.verbose) - studyparser.build_study(self.cursor, self.verbose, continue_from) - studyparser.run_counts_builder(self.cursor, self.schema_name, self.verbose) + + studyparser = StudyManifestParser(target, self.data_path) + try: + if not continue_from: + studyparser.run_protected_table_builder( + self.cursor, self.schema_name, verbose=self.verbose + ) + self.update_transactions(studyparser.get_study_prefix(), "started") + cleaned_tables = studyparser.clean_study( + self.cursor, + self.schema_name, + verbose=self.verbose, + stats_clean=False, + ) + # If the study hasn't been created before, force stats table generation + if len(cleaned_tables) == 0: + stats_build = True + studyparser.run_table_builder( + 
self.cursor, self.schema_name, verbose=self.verbose + ) + else: + self.update_transactions(studyparser.get_study_prefix(), "resumed") + studyparser.build_study(self.cursor, self.verbose, continue_from) + studyparser.run_counts_builders( + self.cursor, self.schema_name, verbose=self.verbose + ) + studyparser.run_statistics_builders( + self.cursor, + self.schema_name, + verbose=self.verbose, + stats_build=stats_build, + data_path=self.data_path, + ) + self.update_transactions(studyparser.get_study_prefix(), "finished") + except Exception as e: + self.update_transactions(studyparser.get_study_prefix(), "error") + raise e def run_single_table_builder( self, target: PosixPath, table_builder_name: str @@ -118,7 +156,7 @@ def run_single_table_builder( self.cursor, self.schema_name, table_builder_name, self.verbose ) - def clean_and_build_all(self, study_dict: Dict) -> None: + def clean_and_build_all(self, study_dict: Dict, export_dir: PosixPath) -> None: """Builds views for all studies. NOTE: By design, this method will always exclude the `template` study dir, @@ -129,7 +167,7 @@ def clean_and_build_all(self, study_dict: Dict) -> None: study_dict = dict(study_dict) study_dict.pop("template") for precursor_study in ["vocab", "core"]: - self.clean_and_build_study(study_dict[precursor_study]) + self.clean_and_build_study(study_dict[precursor_study], export_dir) study_dict.pop(precursor_study) for key in study_dict: self.clean_and_build_study(study_dict[key]) @@ -142,7 +180,7 @@ def export_study(self, target: PosixPath, data_path: PosixPath) -> None: """ if data_path is None: sys.exit("Missing destination - please provide a path argument.") - studyparser = StudyManifestParser(target) + studyparser = StudyManifestParser(target, data_path) studyparser.export_study(self.db, data_path) def export_all(self, study_dict: Dict, data_path: PosixPath): @@ -229,7 +267,7 @@ def run_cli(args: Dict): # all other actions require connecting to AWS else: db_backend = create_db_backend(args) - builder = StudyBuilder(db_backend) + builder = StudyBuilder(db_backend, args["data_path"]) if args["verbose"]: builder.verbose = True print("Testing connection to database...") @@ -250,20 +288,26 @@ def run_cli(args: Dict): builder.clean_study( args["target"], study_dict, + args["stats_clean"], args["prefix"], ) elif args["action"] == "build": if "all" in args["target"]: - builder.clean_and_build_all(study_dict) + builder.clean_and_build_all( + study_dict, args["export_dir"], args["stats_build"] + ) else: for target in args["target"]: if args["builder"]: builder.run_single_table_builder( - study_dict[target], args["builder"] + study_dict[target], args["builder"], args["stats_build"] ) else: builder.clean_and_build_study( - study_dict[target], continue_from=args["continue_from"] + study_dict[target], + args["data_path"], + args["stats_build"], + continue_from=args["continue_from"], ) elif args["action"] == "export": @@ -273,6 +317,13 @@ def run_cli(args: Dict): for target in args["target"]: builder.export_study(study_dict[target], args["data_path"]) + # print(set(builder.cursor.execute("""SELECT table_name + # FROM information_schema.tables + # where table_name ilike '%_lib_%' + # or table_name ilike '%_psm_%'""").fetchall())) + # print(builder.cursor.execute("select * from psm_test__lib_statistics").fetchall()) + # print(builder.cursor.execute("select * from psm_test__lib_transactions").fetchall()) + # print(builder.cursor.execute("select * from psm_test__psm_encounter_covariate").fetchall()) db_backend.close() # returning the 
builder for ease of unit testing return builder @@ -337,6 +388,7 @@ def main(cli_args=None): if args.get("data_path"): args["data_path"] = get_abs_posix_path(args["data_path"]) + return run_cli(args) diff --git a/cumulus_library/cli_parser.py b/cumulus_library/cli_parser.py index 5f951988..fd5c6805 100644 --- a/cumulus_library/cli_parser.py +++ b/cumulus_library/cli_parser.py @@ -1,6 +1,8 @@ """Manages configuration for argparse""" import argparse +from cumulus_library.errors import CumulusLibraryError + def add_target_argument(parser: argparse.ArgumentParser) -> None: """Adds --target arg to a subparser""" @@ -122,6 +124,8 @@ def get_parser() -> argparse.ArgumentParser: dest="action", ) + # Study creation + create = actions.add_parser( "create", help="Create a study instance from a template" ) @@ -135,6 +139,8 @@ def get_parser() -> argparse.ArgumentParser: ), ) + # Database cleaning + clean = actions.add_parser( "clean", help="Removes tables & views beginning with '[target]__' from Athena" ) @@ -143,12 +149,20 @@ def get_parser() -> argparse.ArgumentParser: add_study_dir_argument(clean) add_verbose_argument(clean) add_db_config(clean) + clean.add_argument( + "--statistics", + action="store_true", + help="Remove artifacts of previous statistics runs", + dest="stats_clean", + ) clean.add_argument( "--prefix", action="store_true", help=argparse.SUPPRESS, ) + # Database building + build = actions.add_parser( "build", help="Removes and recreates Athena tables & views for specified studies", @@ -158,7 +172,15 @@ def get_parser() -> argparse.ArgumentParser: add_study_dir_argument(build) add_verbose_argument(build) add_db_config(build) - + build.add_argument( + "--statistics", + action="store_true", + help=( + "Force regenerating statistics data from latest dataset. 
" + "Stats are created by default when study is initially run" + ), + dest="stats_build", + ) build.add_argument( "--load-ndjson-dir", help="Load ndjson files from this folder", metavar="DIR" ) @@ -168,6 +190,8 @@ def get_parser() -> argparse.ArgumentParser: help=argparse.SUPPRESS, ) + # Database export + export = actions.add_parser( "export", help="Generates files on disk from Athena views" ) @@ -177,6 +201,8 @@ def get_parser() -> argparse.ArgumentParser: add_verbose_argument(export) add_db_config(export) + # Aggregator upload + upload = actions.add_parser( "upload", help="Bulk uploads data to Cumulus aggregator" ) diff --git a/cumulus_library/databases.py b/cumulus_library/databases.py index c3c728d2..afc1b089 100644 --- a/cumulus_library/databases.py +++ b/cumulus_library/databases.py @@ -264,10 +264,10 @@ def create_db_backend(args: dict[str, str]) -> DatabaseBackend: args["profile"], database, ) - if load_ndjson_dir: - sys.exit( - "Loading an ndjson dir is not supported with --db-type=athena (try duckdb)" - ) + # if load_ndjson_dir: + # sys.exit( + # "Loading an ndjson dir is not supported with --db-type=athena (try duckdb)" + # ) else: raise ValueError(f"Unexpected --db-type value '{db_type}'") diff --git a/cumulus_library/enums.py b/cumulus_library/enums.py new file mode 100644 index 00000000..d65af4e4 --- /dev/null +++ b/cumulus_library/enums.py @@ -0,0 +1,17 @@ +""" Holds enums used across more than one module """ +from enum import Enum + + +class PROTECTED_TABLE_KEYWORDS(Enum): + """Tables with a pattern like '_{keyword}_' are not manually dropped.""" + + ETL = "etl" + LIB = "lib" + NLP = "nlp" + + +class PROTECTED_TABLES(Enum): + """Tables created by cumulus for persistence outside of study rebuilds""" + + STATISTICS = "lib_statistics" + TRANSACTIONS = "lib_transactions" diff --git a/cumulus_library/protected_table_builder.py b/cumulus_library/protected_table_builder.py new file mode 100644 index 00000000..6953d04a --- /dev/null +++ b/cumulus_library/protected_table_builder.py @@ -0,0 +1,58 @@ +""" Builder for creating tables for tracking state/logging changes""" +import datetime + +from cumulus_library.base_table_builder import BaseTableBuilder +from cumulus_library.enums import PROTECTED_TABLES +from cumulus_library.template_sql.templates import ( + get_ctas_empty_query, + get_create_view_query, +) + +TRANSACTIONS_COLS = ["study_name", "library_version", "status", "event_time"] +STATISTICS_COLS = [ + "study_name", + "library_version", + "table_type", + "table_name", + "view_name", + "created_on", +] + + +class ProtectedTableBuilder(BaseTableBuilder): + display_text = "Creating/updating system tables..." 
+ + def prepare_queries(self, cursor: object, schema: str, study_name: str): + safe_timestamp = ( + datetime.datetime.now() + .replace(microsecond=0) + .isoformat() + .replace(":", "_") + .replace("-", "_") + ) + self.queries.append( + get_ctas_empty_query( + schema, + f"{study_name}__{PROTECTED_TABLES.TRANSACTIONS.value}", + # while it may seem redundant, study name is included for ease + # of constructing a view of multiple transaction tables + TRANSACTIONS_COLS, + ["varchar", "varchar", "varchar", "timestamp"], + ) + ) + self.queries.append( + get_ctas_empty_query( + schema, + f"{study_name}__{PROTECTED_TABLES.STATISTICS.value}", + # same redundancy note about study_name, and also view_name, applies here + STATISTICS_COLS, + [ + "varchar", + "varchar", + "varchar", + "varchar", + "varchar", + "timestamp", + ], + ) + ) diff --git a/cumulus_library/statistics/psm.py b/cumulus_library/statistics/psm.py index da18f13e..6d1b38cc 100644 --- a/cumulus_library/statistics/psm.py +++ b/cumulus_library/statistics/psm.py @@ -1,18 +1,22 @@ # Module for generating Propensity Score matching cohorts -import numpy as np -import pandas +import json +import os import sys + +from pathlib import PosixPath +from dataclasses import dataclass + +import pandas import toml from psmpy import PsmPy +# these imports are mimicing PsmPy imports for re-implemented functions +from psmpy.functions import cohenD +import matplotlib.pyplot as plt +import seaborn as sns -import json -from pathlib import PosixPath -from dataclasses import dataclass - -from cumulus_library.cli import StudyBuilder from cumulus_library.databases import DatabaseCursor from cumulus_library.base_table_builder import BaseTableBuilder from cumulus_library.template_sql.templates import ( @@ -53,11 +57,12 @@ class PsmBuilder(BaseTableBuilder): display_text = "Building PSM tables..." 
- def __init__(self, toml_config_path: str): + def __init__(self, toml_config_path: str, export_path: PosixPath): """Loads PSM job details from a PSM configuration file""" super().__init__() # We're stashing the toml path for error reporting later self.toml_path = toml_config_path + self.export_path = export_path try: with open(self.toml_path, encoding="UTF-8") as file: toml_config = toml.load(file) @@ -132,7 +137,9 @@ def _get_sampled_ids( df[dependent_variable] = is_positive return df - def _create_covariate_table(self, cursor: DatabaseCursor, schema: str): + def _create_covariate_table( + self, cursor: DatabaseCursor, schema: str, table_suffix: str + ): """Creates a covariate table from the loaded toml config""" # checks for primary & link ref being the same source_refs = list({self.config.primary_ref, self.config.count_ref} - {None}) @@ -163,23 +170,24 @@ def _create_covariate_table(self, cursor: DatabaseCursor, schema: str): # Replace table (if it exists) # TODO - replace with timestamp prepended table in future PR - drop = get_drop_view_table( - f"{self.config.pos_source_table}_sampled_ids", "TABLE" - ) - cursor.execute(drop) + # drop = get_drop_view_table( + # f"{self.config.pos_source_table}_sampled_ids", "TABLE" + # ) + # cursor.execute(drop) ctas_query = get_ctas_query_from_df( schema, - f"{self.config.pos_source_table}_sampled_ids", + f"{self.config.pos_source_table}_sampled_ids_{table_suffix}", cohort, ) self.queries.append(ctas_query) # TODO - replace with timestamp prepended table - drop = get_drop_view_table(self.config.target_table, "TABLE") - cursor.execute(drop) + # drop = get_drop_view_table(self.config.target_table, "TABLE") + # cursor.execute(drop) dataset_query = get_create_covariate_table( - target_table=self.config.target_table, + target_table=f"{self.config.target_table}_{table_suffix}", pos_source_table=self.config.pos_source_table, neg_source_table=self.config.neg_source_table, + table_suffix=table_suffix, primary_ref=self.config.primary_ref, dependent_variable=self.config.dependent_variable, join_cols_by_table=self.config.join_cols_by_table, @@ -188,9 +196,90 @@ def _create_covariate_table(self, cursor: DatabaseCursor, schema: str): ) self.queries.append(dataset_query) - def generate_psm_analysis(self, cursor: DatabaseCursor, schema: str): + def psm_plot_match( + self, + psm, + matched_entity="propensity_logit", + Title="Side by side matched controls", + Ylabel="Number of patients", + Xlabel="propensity logit", + names=["positive_cohort", "negative_cohort"], + colors=["#E69F00", "#56B4E9"], + save=True, + filename="propensity_match.png", + ): + """Plots knn match data + + This function re-implements psm.plot_match, with the only changes + allowing for specifiying a filename/location for saving plots to, + and passing in the psm object instead of assuming a call from inside + the PsmPy class. 
+ """ + dftreat = psm.df_matched[psm.df_matched[psm.treatment] == 1] + dfcontrol = psm.df_matched[psm.df_matched[psm.treatment] == 0] + x1 = dftreat[matched_entity] + x2 = dfcontrol[matched_entity] + colors = colors + names = names + sns.set_style("white") + plt.hist([x1, x2], color=colors, label=names) + plt.legend() + plt.xlabel(Xlabel) + plt.ylabel(Ylabel) + plt.title(Title) + if save == True: + plt.savefig(filename, dpi=250) + + def psm_effect_size_plot( + self, + psm, + title="Standardized Mean differences accross covariates before and after matching", + before_color="#FCB754", + after_color="#3EC8FB", + save=False, + filename="effect_size.png", + ): + """Plots effect size of variables for positive/negative matches + + This function re-implements psm.effect_size_plot, with the only changes + allowing for specifiying a filename/location for saving plots to, + and passing in the psm object instead of assuming a call from inside + the PsmPy class. + """ + df_preds_after = psm.df_matched[[psm.treatment] + psm.xvars] + df_preds_b4 = psm.data[[psm.treatment] + psm.xvars] + df_preds_after_float = df_preds_after.astype(float) + df_preds_b4_float = df_preds_b4.astype(float) + + data = [] + for cl in psm.xvars: + data.append([cl, "before", cohenD(df_preds_b4_float, psm.treatment, cl)]) + data.append([cl, "after", cohenD(df_preds_after_float, psm.treatment, cl)]) + psm.effect_size = pandas.DataFrame( + data, columns=["Variable", "matching", "Effect Size"] + ) + sns.set_style("white") + sns_plot = sns.barplot( + data=psm.effect_size, + y="Variable", + x="Effect Size", + hue="matching", + palette=[before_color, after_color], + orient="h", + ) + sns_plot.set(title=title) + if save == True: + sns_plot.figure.savefig(filename, dpi=250, bbox_inches="tight") + + def generate_psm_analysis( + self, cursor: DatabaseCursor, schema: str, table_suffix: str + ): """Runs PSM statistics on generated tables""" - df = cursor.execute(f"select * from {self.config.target_table}").as_pandas() + cursor.execute( + f"""CREATE OR REPLACE VIEW {self.config.target_table} + AS SELECT * FROM {self.config.target_table}_{table_suffix}""" + ) + df = cursor.execute(f"SELECT * FROM {self.config.target_table}").as_pandas() symptoms_dict = self._get_symptoms_dict(self.config.classification_json) for dependent_variable, codes in symptoms_dict.items(): df[dependent_variable] = df["code"].apply(lambda x: 1 if x in codes else 0) @@ -247,7 +336,19 @@ def generate_psm_analysis(self, cursor: DatabaseCursor, schema: str): caliper=None, drop_unmatched=True, ) - + os.makedirs(self.export_path, exist_ok=True) + self.psm_plot_match( + psm, + save=True, + filename=self.export_path + / f"{self.config.target_table}_{table_suffix}_propensity_match.png", + ) + self.psm_effect_size_plot( + psm, + save=True, + filename=self.export_path + / f"{self.config.target_table}_{table_suffix}_effect_size.png", + ) except ZeroDivisionError: sys.exit( "Encountered a divide by zero error during statistical graph generation. Try increasing your sample size." @@ -257,8 +358,8 @@ def generate_psm_analysis(self, cursor: DatabaseCursor, schema: str): "Encountered a value error during KNN matching. Try increasing your sample size." 
) - def prepare_queries(self, cursor: object, schema: str): - self._create_covariate_table(cursor, schema) + def prepare_queries(self, cursor: object, schema: str, table_suffix: str): + self._create_covariate_table(cursor, schema, table_suffix) def post_execution( self, @@ -266,6 +367,7 @@ def post_execution( schema: str, verbose: bool, drop_table: bool = False, + table_suffix: str = None, ): # super().execute_queries(cursor, schema, verbose, drop_table) - self.generate_psm_analysis(cursor, schema) + self.generate_psm_analysis(cursor, schema, table_suffix) diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index e2e0d112..27049dcc 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -3,6 +3,7 @@ import importlib.util import sys +from datetime import datetime from pathlib import Path, PosixPath from typing import List, Optional @@ -10,8 +11,10 @@ from rich.progress import Progress, TaskID, track +from cumulus_library import __version__ from cumulus_library.base_table_builder import BaseTableBuilder from cumulus_library.databases import DatabaseBackend, DatabaseCursor +from cumulus_library.enums import PROTECTED_TABLE_KEYWORDS, PROTECTED_TABLES from cumulus_library.errors import StudyManifestParsingError from cumulus_library.helper import ( query_console_output, @@ -19,16 +22,17 @@ parse_sql, get_progress_bar, ) +from cumulus_library.protected_table_builder import ProtectedTableBuilder +from cumulus_library.statistics.psm import PsmBuilder from cumulus_library.template_sql.templates import ( get_show_tables, get_show_views, get_drop_view_table, + get_insert_into_query, ) StrList = List[str] -RESERVED_TABLE_KEYWORDS = ["etl", "nlp", "lib"] - class StudyManifestParser: """Handles loading of study data from manifest files. @@ -38,21 +42,22 @@ class StudyManifestParser: mechanisms for IDing studies/files of interest, and for executing queries, but specifically it should never be in charge of instantiation a cursor itself - this will help to future proof against other database implementations in the - future, assuming those DBs have a PEP-249 cursor available (and this is why we - are hinting generic objects for cursors). - + future. """ _study_path = None _study_config = {} - def __init__(self, study_path: Optional[Path] = None): + def __init__( + self, study_path: Optional[Path] = None, data_path: Optional[Path] = None + ): """Instantiates a StudyManifestParser. :param study_path: A pathlib Path object, optional """ if study_path is not None: self.load_study_manifest(study_path) + self.data_path = data_path def __repr__(self): return str(self._study_config) @@ -119,6 +124,14 @@ def get_counts_builder_file_list(self) -> Optional[StrList]: sql_config = self._study_config.get("counts_builder_config", {}) return sql_config.get("file_names", []) + def get_statistics_file_list(self) -> Optional[StrList]: + """Reads the contents of the statistics_config array from the manifest + + :returns: An array of statistics toml files from the manifest, or None if not found. 
+ """ + stats_config = self._study_config.get("statistics_config", {}) + return stats_config.get("file_names", []) + def get_export_table_list(self) -> Optional[StrList]: """Reads the contents of the export_list array from the manifest @@ -134,14 +147,14 @@ def get_export_table_list(self) -> Optional[StrList]: ) return export_table_list - def reset_export_dir(self) -> None: + def reset_data_dir(self) -> None: """ Removes exports associated with this study from the ../data_export directory. """ project_path = Path(__file__).resolve().parents[1] - path = Path(f"{str(project_path)}/data_export/{self.get_study_prefix()}/") + path = self.data_path / self.get_study_prefix() if path.exists(): - for file in path.glob("*"): + for file in path.glob("*.*"): file.unlink() # SQL related functions @@ -149,6 +162,7 @@ def clean_study( self, cursor: DatabaseCursor, schema_name: str, + stats_clean: bool = False, verbose: bool = False, prefix: str = None, ) -> List: @@ -168,12 +182,32 @@ def clean_study( else: drop_prefix = prefix display_prefix = drop_prefix + + if stats_clean: + confirm = input( + "This will remove all historical stats tables beginning in the " + f"{display_prefix} study - are you sure? (y/N)" + ) + if confirm.lower() not in ("y", "yes"): + sys.exit("Table cleaning aborted") + view_sql = get_show_views(schema_name, drop_prefix) table_sql = get_show_tables(schema_name, drop_prefix) view_table_list = [] - for query_and_type in [[view_sql, "VIEW"], [table_sql, "TABLE"]]: - cursor.execute(query_and_type[0]) - for db_row_tuple in cursor.fetchall(): + for query_and_type in [[view_sql, "VIEW"], [table_sql, "TABLE"]]: # + tuple_list = cursor.execute(query_and_type[0]).fetchall() + if ( + f"{drop_prefix}{PROTECTED_TABLES.STATISTICS.value}", + ) in tuple_list and not stats_clean: + protected_list = cursor.execute( + f"""SELECT {(query_and_type[1]).lower()}_name + FROM {drop_prefix}{PROTECTED_TABLES.STATISTICS.value} + WHERE study_name = '{display_prefix}'""" + ).fetchall() + for protected_tuple in protected_list: + if protected_tuple in tuple_list: + tuple_list.remove(protected_tuple) + for db_row_tuple in tuple_list: # this check handles athena reporting views as also being tables, # so we don't waste time dropping things that don't exist if query_and_type[1] == "TABLE": @@ -191,8 +225,11 @@ def clean_study( # study builder, and remove them from the list. for view_table in view_table_list.copy(): if any( - ((f"_{word}_") in view_table[0] or view_table[0].endswith(word)) - for word in RESERVED_TABLE_KEYWORDS + ( + (f"_{word.value}_") in view_table[0] + or view_table[0].endswith(word.value) + ) + for word in PROTECTED_TABLE_KEYWORDS ): view_table_list.remove(view_table) # We want to only show a progress bar if we are :not: printing SQL lines @@ -216,6 +253,12 @@ def clean_study( progress, task, ) + if stats_clean: + drop_query = get_drop_view_table( + f"{drop_prefix}{PROTECTED_TABLES.STATISTICS.value}", "TABLE" + ) + cursor.execute(drop_query) + return view_table_list def _execute_drop_queries( @@ -289,12 +332,26 @@ def _load_and_execute_builder( table_builder = table_builder_class() table_builder.execute_queries(cursor, schema, verbose, drop_table) - # After runnning the executor code, we'll remove - # remove it so it doesn't interfere with the next python module to + # After running the executor code, we'll remove + # it so it doesn't interfere with the next python module to # execute, since the subclass would otherwise hang around. 
del sys.modules[table_builder_module.__name__] del table_builder_module + def run_protected_table_builder( + self, cursor: DatabaseCursor, schema: str, verbose: bool = False + ) -> None: + """Creates protected tables for persisting selected data across runs + + :param cursor: A PEP-249 compatible cursor object + :param schema: The name of the schema to write tables to + :param verbose: toggle from progress bar to query output + """ + ptb = ProtectedTableBuilder() + ptb.execute_queries( + cursor, schema, verbose, study_name=self._study_config.get("study_prefix") + ) + def run_table_builder( self, cursor: DatabaseCursor, schema: str, verbose: bool = False ) -> None: @@ -307,11 +364,16 @@ def run_table_builder( for file in self.get_table_builder_file_list(): self._load_and_execute_builder(file, cursor, schema, verbose) - def run_counts_builder( + def run_counts_builders( self, cursor: DatabaseCursor, schema: str, verbose: bool = False ) -> None: """Loads counts modules from a manifest and executes code via BaseTableBuilder + While a count is a form of statistics, it is treated separately from other + statistics because it is, by design, always going to be static against a + given dataset, where other statistical methods may use sampling techniques + or adjustable input parameters that may need to be preserved for later review. + :param cursor: A PEP-249 compatible cursor object :param schema: The name of the schema to write tables to :param verbose: toggle from progress bar to query output @@ -319,6 +381,69 @@ def run_counts_builder( for file in self.get_counts_builder_file_list(): self._load_and_execute_builder(file, cursor, schema, verbose) + def run_statistics_builders( + self, + cursor: DatabaseCursor, + schema: str, + verbose: bool = False, + stats_build: bool = False, + data_path: PosixPath = None, + ) -> None: + """Loads statistics modules from toml definitions and executes + + :param cursor: A PEP-249 compatible cursor object + :param schema: The name of the schema to write tables to + :param verbose: toggle from progress bar to query output + """ + if not stats_build: + return + for file in self.get_statistics_file_list(): + # This open is a bit redundant with the open inside of the PSM builder, + # but we're letting it slide so that builders function similarly + # across the board + iso_timestamp = datetime.now().replace(microsecond=0).isoformat() + safe_timestamp = iso_timestamp.replace(":", "_").replace("-", "_") + toml_path = Path(f"{self._study_path}/{file}") + with open(toml_path, encoding="UTF-8") as file: + config = toml.load(file) + config_type = config["config_type"] + target_table = config["target_table"] + if config_type == "psm": + builder = PsmBuilder( + toml_path, self.data_path / f"{self.get_study_prefix()}/psm" + ) + else: + raise StudyManifestParsingError( + f"{toml_path} references an invalid statistics type {config_type}." 
+ ) + builder.execute_queries( + cursor, schema, verbose, table_suffix=safe_timestamp + ) + + insert_query = get_insert_into_query( + f"{self.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + [ + "study_name", + "library_version", + "table_type", + "table_name", + "view_name", + "created_on", + ], + [ + [ + self.get_study_prefix(), + __version__, + config_type, + f"{target_table}_{safe_timestamp}", + target_table, + iso_timestamp, + ] + ], + ) + cursor.execute(insert_query) + # self._load_and_execute_builder(file, cursor, schema, verbose) + def run_single_table_builder( self, cursor: DatabaseCursor, schema: str, name: str, verbose: bool = False ): @@ -395,8 +520,8 @@ def _execute_build_queries( "should be in the first line of the query.", ) if any( - f" {self.get_study_prefix()}__{word}_" in create_line - for word in RESERVED_TABLE_KEYWORDS + f" {self.get_study_prefix()}__{word.value}_" in create_line + for word in PROTECTED_TABLE_KEYWORDS ): self._query_error( query, @@ -404,7 +529,7 @@ def _execute_build_queries( "immediately after the study prefix. Please rename this table so " "that is does not begin with one of these special words " "immediately after the double undescore.\n" - f"Reserved words: {str(RESERVED_TABLE_KEYWORDS)}", + f"Reserved words: {str(word.value for word in PROTECTED_TABLE_KEYWORDS)}", ) if create_line.count("__") > 1: self._query_error( @@ -439,7 +564,7 @@ def export_study(self, db: DatabaseBackend, data_path: PosixPath) -> List: :param db: A database backend :returns: list of executed queries (for unit testing only) """ - self.reset_export_dir() + self.reset_data_dir() queries = [] for table in track( self.get_export_table_list(), diff --git a/cumulus_library/template_sql/ctas_empty.sql.jinja b/cumulus_library/template_sql/ctas_empty.sql.jinja index f65a1a78..c13f1df4 100644 --- a/cumulus_library/template_sql/ctas_empty.sql.jinja +++ b/cumulus_library/template_sql/ctas_empty.sql.jinja @@ -1,9 +1,10 @@ -CREATE TABLE "{{ schema_name }}"."{{ table_name }}" AS ( +CREATE TABLE IF NOT EXISTS "{{ schema_name }}"."{{ table_name }}" +AS ( SELECT * FROM ( VALUES ( - {%- for col in table_cols -%} - cast(NULL AS varchar) + {%- for type in table_cols_types -%} + cast(NULL AS {{ type }}) {%- if not loop.last -%} , {%- endif -%} diff --git a/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja b/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja index 8e5d2558..2ebde5bc 100644 --- a/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja +++ b/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja @@ -33,7 +33,7 @@ CREATE TABLE {{ target_table }} AS ( {%- endif -%} {{ select_column_or_alias(join_cols_by_table) }} {{ neg_source_table }}.code - FROM "{{ pos_source_table }}_sampled_ids" AS sample_cohort, + FROM "{{ pos_source_table }}_sampled_ids_{{table_suffix}}" AS sample_cohort, "{{ neg_source_table }}", {%- for key in join_cols_by_table %} "{{ key }}" diff --git a/cumulus_library/template_sql/statistics/psm_templates.py b/cumulus_library/template_sql/statistics/psm_templates.py index 71082e01..928c82e7 100644 --- a/cumulus_library/template_sql/statistics/psm_templates.py +++ b/cumulus_library/template_sql/statistics/psm_templates.py @@ -44,6 +44,7 @@ def get_create_covariate_table( target_table: str, pos_source_table: str, neg_source_table: str, + table_suffix: str, primary_ref: str, dependent_variable: str, join_cols_by_table: dict, @@ -75,6 +76,7 @@ def 
get_create_covariate_table( target_table=target_table, pos_source_table=pos_source_table, neg_source_table=neg_source_table, + table_suffix=table_suffix, primary_ref=primary_ref, dependent_variable=dependent_variable, count_ref=count_ref, diff --git a/cumulus_library/template_sql/templates.py b/cumulus_library/template_sql/templates.py index cbacece1..330754e0 100644 --- a/cumulus_library/template_sql/templates.py +++ b/cumulus_library/template_sql/templates.py @@ -203,7 +203,10 @@ def get_ctas_query_from_df(schema_name: str, table_name: str, df: DataFrame) -> def get_ctas_empty_query( - schema_name: str, table_name: str, table_cols: List[str] + schema_name: str, + table_name: str, + table_cols: List[str], + table_cols_types: List[str] = [], ) -> str: """Generates a create table as query for initializing an empty table @@ -215,13 +218,18 @@ def get_ctas_empty_query( :param schema_name: The athena schema to create the table in :param table_name: The name of the athena table to create :param table_cols: Comma deleniated column names, i.e. ['first,second'] + :param table_cols: Allows specifying a data type per column (default: all varchar) """ path = Path(__file__).parent + if table_cols_types == []: + for col in table_cols: + table_cols_types.append("varchar") with open(f"{path}/ctas_empty.sql.jinja") as ctas_empty: return Template(ctas_empty.read()).render( schema_name=schema_name, table_name=table_name, table_cols=table_cols, + table_cols_types=table_cols_types, ) diff --git a/docs/statistics/propensity-score-matching.md b/docs/statistics/propensity-score-matching.md index a002e3a4..aebbd322 100644 --- a/docs/statistics/propensity-score-matching.md +++ b/docs/statistics/propensity-score-matching.md @@ -51,6 +51,10 @@ details on the expectations of each value. # database. We recommend that you only attempt to use this after you have decided # on the first draft of your cohort selection criteria +# config_type should always be "psm" - we use this to distinguish from other +# statistic type runs +config_type = "psm" + # classification_json should reference a file in the same directory as this config, # which matches a category to a set of ICD codes. As an example, you could use # an existing guide like DSM5 classifications for this, but you could also use diff --git a/tests/test_data/duckdb_data/duck.db b/tests/test_data/duckdb_data/duck.db deleted file mode 100644 index ababe675cb4d91b04d53fbe3248702064e5d6695..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 12288 zcmeI#F$#lF3;<9UkKi>vqO|32Lc6&MoyFEOc!_Skk5kX#;zin7L1>qP@Un#j5?-b> zQ~MZS&-rpa*H!Xe40(v5*zPUw$9U*Zw=Qn?s1P7PfB*pk1PBlyK!5-N0{;_8v*oP! 
xDwb6l0RjXF5FkK+009C72oNB!M*;u;9 Date: Wed, 20 Dec 2023 13:28:46 -0500 Subject: [PATCH 02/13] existing unit test rework, data generator --- cumulus_library/.sqlfluff | 2 + cumulus_library/base_table_builder.py | 2 - cumulus_library/cli.py | 24 +--- cumulus_library/cli_parser.py | 3 +- cumulus_library/databases.py | 9 +- cumulus_library/protected_table_builder.py | 43 +++--- cumulus_library/statistics/psm.py | 75 +++++----- cumulus_library/study_parser.py | 25 ++-- .../template_sql/ctas_empty.sql.jinja | 2 +- .../psm_create_covariate_table.sql.jinja | 2 +- cumulus_library/template_sql/templates.py | 2 +- pyproject.toml | 4 +- tests/conftest.py | 132 +++++++++++++++-- tests/test_cli.py | 11 +- tests/test_conftest.py | 45 ++++++ tests/test_core.py | 8 +- tests/test_data/psm/psm_config.toml | 4 +- .../test_data/psm/psm_config_no_optional.toml | 7 +- tests/test_psm.py | 77 ++++++---- tests/test_psm_templates.py | 11 +- tests/test_study_parser.py | 135 +++++++++++------- 21 files changed, 412 insertions(+), 211 deletions(-) create mode 100644 tests/test_conftest.py diff --git a/cumulus_library/.sqlfluff b/cumulus_library/.sqlfluff index 536cb290..d755da80 100644 --- a/cumulus_library/.sqlfluff +++ b/cumulus_library/.sqlfluff @@ -46,7 +46,9 @@ schema_name = test_schema source_table = source_table source_id = source_id table_cols = ["a","b"] +table_cols_types = ["varchar", "varchar"] table_name = test_table +table_suffix = 2024_01_01_11_11_11 target_col_prefix = prefix target_table = target_table unnests = [{"source col": "g", "table_alias": "i", "row_alias":"j"}, {"source col": "k", "table_alias": "l", "row_alias":"m"},] diff --git a/cumulus_library/base_table_builder.py b/cumulus_library/base_table_builder.py index dfcd41fe..2e04bd47 100644 --- a/cumulus_library/base_table_builder.py +++ b/cumulus_library/base_table_builder.py @@ -30,8 +30,6 @@ def prepare_queries(self, cursor: object, schema: str, *args, **kwargs): :param cursor: A PEP-249 compatible cursor :param schema: A schema name - :param db_type: The db system being used (only relevant for db-specific - query construction) """ raise NotImplementedError diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index b8d04de0..ffafb762 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -98,7 +98,6 @@ def clean_study( def clean_and_build_study( self, target: PosixPath, - export_dir: PosixPath, stats_build: bool, continue_from: str = None, ) -> None: @@ -106,7 +105,6 @@ def clean_and_build_study( :param target: A PosixPath to the study directory """ - studyparser = StudyManifestParser(target, self.data_path) try: if not continue_from: @@ -140,6 +138,8 @@ def clean_and_build_study( data_path=self.data_path, ) self.update_transactions(studyparser.get_study_prefix(), "finished") + except SystemExit as exit: + raise exit except Exception as e: self.update_transactions(studyparser.get_study_prefix(), "error") raise e @@ -156,7 +156,7 @@ def run_single_table_builder( self.cursor, self.schema_name, table_builder_name, self.verbose ) - def clean_and_build_all(self, study_dict: Dict, export_dir: PosixPath) -> None: + def clean_and_build_all(self, study_dict: Dict, stats_build: bool) -> None: """Builds views for all studies. 
NOTE: By design, this method will always exclude the `template` study dir, @@ -167,10 +167,10 @@ def clean_and_build_all(self, study_dict: Dict, export_dir: PosixPath) -> None: study_dict = dict(study_dict) study_dict.pop("template") for precursor_study in ["vocab", "core"]: - self.clean_and_build_study(study_dict[precursor_study], export_dir) + self.clean_and_build_study(study_dict[precursor_study], stats_build) study_dict.pop(precursor_study) for key in study_dict: - self.clean_and_build_study(study_dict[key]) + self.clean_and_build_study(study_dict[key], stats_build) ### Data exporters def export_study(self, target: PosixPath, data_path: PosixPath) -> None: @@ -267,7 +267,7 @@ def run_cli(args: Dict): # all other actions require connecting to AWS else: db_backend = create_db_backend(args) - builder = StudyBuilder(db_backend, args["data_path"]) + builder = StudyBuilder(db_backend, data_path=args.get("data_path", None)) if args["verbose"]: builder.verbose = True print("Testing connection to database...") @@ -293,9 +293,7 @@ def run_cli(args: Dict): ) elif args["action"] == "build": if "all" in args["target"]: - builder.clean_and_build_all( - study_dict, args["export_dir"], args["stats_build"] - ) + builder.clean_and_build_all(study_dict, args["stats_build"]) else: for target in args["target"]: if args["builder"]: @@ -305,7 +303,6 @@ def run_cli(args: Dict): else: builder.clean_and_build_study( study_dict[target], - args["data_path"], args["stats_build"], continue_from=args["continue_from"], ) @@ -317,13 +314,6 @@ def run_cli(args: Dict): for target in args["target"]: builder.export_study(study_dict[target], args["data_path"]) - # print(set(builder.cursor.execute("""SELECT table_name - # FROM information_schema.tables - # where table_name ilike '%_lib_%' - # or table_name ilike '%_psm_%'""").fetchall())) - # print(builder.cursor.execute("select * from psm_test__lib_statistics").fetchall()) - # print(builder.cursor.execute("select * from psm_test__lib_transactions").fetchall()) - # print(builder.cursor.execute("select * from psm_test__psm_encounter_covariate").fetchall()) db_backend.close() # returning the builder for ease of unit testing return builder diff --git a/cumulus_library/cli_parser.py b/cumulus_library/cli_parser.py index fd5c6805..b5f15bd3 100644 --- a/cumulus_library/cli_parser.py +++ b/cumulus_library/cli_parser.py @@ -1,8 +1,6 @@ """Manages configuration for argparse""" import argparse -from cumulus_library.errors import CumulusLibraryError - def add_target_argument(parser: argparse.ArgumentParser) -> None: """Adds --target arg to a subparser""" @@ -171,6 +169,7 @@ def get_parser() -> argparse.ArgumentParser: add_table_builder_argument(build) add_study_dir_argument(build) add_verbose_argument(build) + add_data_path_argument(build) add_db_config(build) build.add_argument( "--statistics", diff --git a/cumulus_library/databases.py b/cumulus_library/databases.py index afc1b089..3857151e 100644 --- a/cumulus_library/databases.py +++ b/cumulus_library/databases.py @@ -13,6 +13,7 @@ import json import os import sys +import warnings from pathlib import Path from typing import Optional, Protocol, Union @@ -264,10 +265,10 @@ def create_db_backend(args: dict[str, str]) -> DatabaseBackend: args["profile"], database, ) - # if load_ndjson_dir: - # sys.exit( - # "Loading an ndjson dir is not supported with --db-type=athena (try duckdb)" - # ) + if load_ndjson_dir: + warnings.warn( + "Loading an ndjson dir is not supported with --db-type=athena." 
+ ) else: raise ValueError(f"Unexpected --db-type value '{db_type}'") diff --git a/cumulus_library/protected_table_builder.py b/cumulus_library/protected_table_builder.py index 6953d04a..b8deb85b 100644 --- a/cumulus_library/protected_table_builder.py +++ b/cumulus_library/protected_table_builder.py @@ -22,14 +22,10 @@ class ProtectedTableBuilder(BaseTableBuilder): display_text = "Creating/updating system tables..." - def prepare_queries(self, cursor: object, schema: str, study_name: str): - safe_timestamp = ( - datetime.datetime.now() - .replace(microsecond=0) - .isoformat() - .replace(":", "_") - .replace("-", "_") - ) + def prepare_queries( + self, cursor: object, schema: str, study_name: str, study_stats: dict + ): + print("hi") self.queries.append( get_ctas_empty_query( schema, @@ -40,19 +36,20 @@ def prepare_queries(self, cursor: object, schema: str, study_name: str): ["varchar", "varchar", "varchar", "timestamp"], ) ) - self.queries.append( - get_ctas_empty_query( - schema, - f"{study_name}__{PROTECTED_TABLES.STATISTICS.value}", - # same redundancy note about study_name, and also view_name, applies here - STATISTICS_COLS, - [ - "varchar", - "varchar", - "varchar", - "varchar", - "varchar", - "timestamp", - ], + if study_stats: + self.queries.append( + get_ctas_empty_query( + schema, + f"{study_name}__{PROTECTED_TABLES.STATISTICS.value}", + # same redundancy note about study_name, and also view_name, applies here + STATISTICS_COLS, + [ + "varchar", + "varchar", + "varchar", + "varchar", + "varchar", + "timestamp", + ], + ) ) - ) diff --git a/cumulus_library/statistics/psm.py b/cumulus_library/statistics/psm.py index 6d1b38cc..c3a71e06 100644 --- a/cumulus_library/statistics/psm.py +++ b/cumulus_library/statistics/psm.py @@ -3,6 +3,7 @@ import json import os import sys +import warnings from pathlib import PosixPath from dataclasses import dataclass @@ -57,12 +58,12 @@ class PsmBuilder(BaseTableBuilder): display_text = "Building PSM tables..." 
- def __init__(self, toml_config_path: str, export_path: PosixPath): + def __init__(self, toml_config_path: str, data_path: PosixPath): """Loads PSM job details from a PSM configuration file""" super().__init__() # We're stashing the toml path for error reporting later self.toml_path = toml_config_path - self.export_path = export_path + self.data_path = data_path try: with open(self.toml_path, encoding="UTF-8") as file: toml_config = toml.load(file) @@ -166,23 +167,15 @@ def _create_covariate_table( self.config.dependent_variable, 0, ) - cohort = pandas.concat([pos, neg]) - # Replace table (if it exists) - # TODO - replace with timestamp prepended table in future PR - # drop = get_drop_view_table( - # f"{self.config.pos_source_table}_sampled_ids", "TABLE" - # ) - # cursor.execute(drop) + cohort = pandas.concat([pos, neg]) ctas_query = get_ctas_query_from_df( schema, f"{self.config.pos_source_table}_sampled_ids_{table_suffix}", cohort, ) self.queries.append(ctas_query) - # TODO - replace with timestamp prepended table - # drop = get_drop_view_table(self.config.target_table, "TABLE") - # cursor.execute(drop) + dataset_query = get_create_covariate_table( target_table=f"{self.config.target_table}_{table_suffix}", pos_source_table=self.config.pos_source_table, @@ -195,6 +188,7 @@ def _create_covariate_table( count_table=self.config.count_table, ) self.queries.append(dataset_query) + print(dataset_query) def psm_plot_match( self, @@ -315,7 +309,6 @@ def generate_psm_analysis( df = pandas.concat([df, encoded_df], axis=1) df = df.drop(column, axis=1) df = df.reset_index() - try: psm = PsmPy( df, @@ -323,32 +316,36 @@ def generate_psm_analysis( indx=self.config.primary_ref, exclude=[], ) - # This function populates the psm.predicted_data element, which is required - # for things like the knn_matched() function call - # TODO: create graph from this data - psm.logistic_ps(balance=True) - # This function populates the psm.df_matched element - # TODO: flip replacement to false after increasing sample data size - # TODO: create graph from this data - psm.knn_matched( - matcher="propensity_logit", - replacement=True, - caliper=None, - drop_unmatched=True, - ) - os.makedirs(self.export_path, exist_ok=True) - self.psm_plot_match( - psm, - save=True, - filename=self.export_path - / f"{self.config.target_table}_{table_suffix}_propensity_match.png", - ) - self.psm_effect_size_plot( - psm, - save=True, - filename=self.export_path - / f"{self.config.target_table}_{table_suffix}_effect_size.png", - ) + + # we expect psmpy to autodrop non-matching values, so we'll surpress it + # mentioning workarounds for this behavior. 
+ with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=UserWarning) + # This function populates the psm.predicted_data element, which is required + # for things like the knn_matched() function call + # TODO: create graph from this data + psm.logistic_ps(balance=True) + # This function populates the psm.df_matched element + # TODO: create graph from this data + psm.knn_matched( + matcher="propensity_logit", + replacement=False, + caliper=None, + drop_unmatched=True, + ) + os.makedirs(self.data_path, exist_ok=True) + self.psm_plot_match( + psm, + save=True, + filename=self.data_path + / f"{self.config.target_table}_{table_suffix}_propensity_match.png", + ) + self.psm_effect_size_plot( + psm, + save=True, + filename=self.data_path + / f"{self.config.target_table}_{table_suffix}_effect_size.png", + ) except ZeroDivisionError: sys.exit( "Encountered a divide by zero error during statistical graph generation. Try increasing your sample size." diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index 27049dcc..7e7bfcd6 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -151,9 +151,12 @@ def reset_data_dir(self) -> None: """ Removes exports associated with this study from the ../data_export directory. """ - project_path = Path(__file__).resolve().parents[1] - path = self.data_path / self.get_study_prefix() + print(self.data_path) + print(type(self.data_path)) + path = Path(f"{self.data_path}/{self.get_study_prefix()}") if path.exists(): + # we're just going to remove the count files - exports related to stats + # that aren't uploaded to the aggregator are left alone. for file in path.glob("*.*"): file.unlink() @@ -188,13 +191,13 @@ def clean_study( "This will remove all historical stats tables beginning in the " f"{display_prefix} study - are you sure? (y/N)" ) - if confirm.lower() not in ("y", "yes"): + if confirm is None or confirm.lower() not in ("y", "yes"): sys.exit("Table cleaning aborted") view_sql = get_show_views(schema_name, drop_prefix) table_sql = get_show_tables(schema_name, drop_prefix) view_table_list = [] - for query_and_type in [[view_sql, "VIEW"], [table_sql, "TABLE"]]: # + for query_and_type in [[view_sql, "VIEW"], [table_sql, "TABLE"]]: tuple_list = cursor.execute(query_and_type[0]).fetchall() if ( f"{drop_prefix}{PROTECTED_TABLES.STATISTICS.value}", @@ -232,7 +235,7 @@ def clean_study( for word in PROTECTED_TABLE_KEYWORDS ): view_table_list.remove(view_table) - # We want to only show a progress bar if we are :not: printing SQL lines + if prefix: print("The following views/tables were selected by prefix:") for view_table in view_table_list: @@ -240,6 +243,7 @@ def clean_study( confirm = input("Remove these tables? 
(y/N)") if confirm.lower() not in ("y", "yes"): sys.exit("Table cleaning aborted") + # We want to only show a progress bar if we are :not: printing SQL lines with get_progress_bar(disable=verbose) as progress: task = progress.add_task( f"Removing {display_prefix} study artifacts...", @@ -349,7 +353,11 @@ def run_protected_table_builder( """ ptb = ProtectedTableBuilder() ptb.execute_queries( - cursor, schema, verbose, study_name=self._study_config.get("study_prefix") + cursor, + schema, + verbose, + study_name=self._study_config.get("study_prefix"), + study_stats=self._study_config.get("statistics_config"), ) def run_table_builder( @@ -393,7 +401,9 @@ def run_statistics_builders( :param cursor: A PEP-249 compatible cursor object :param schema: The name of the schema to write tables to - :param verbose: toggle from progress bar to query output + :keyword verbose: toggle from progress bar to query output + :keyword stats_build: If true, will run statistical sampling & table generation + :keyword data_path: A path to where stats output artifacts should be stored """ if not stats_build: return @@ -442,7 +452,6 @@ def run_statistics_builders( ], ) cursor.execute(insert_query) - # self._load_and_execute_builder(file, cursor, schema, verbose) def run_single_table_builder( self, cursor: DatabaseCursor, schema: str, name: str, verbose: bool = False diff --git a/cumulus_library/template_sql/ctas_empty.sql.jinja b/cumulus_library/template_sql/ctas_empty.sql.jinja index c13f1df4..980d4f22 100644 --- a/cumulus_library/template_sql/ctas_empty.sql.jinja +++ b/cumulus_library/template_sql/ctas_empty.sql.jinja @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS "{{ schema_name }}"."{{ table_name }}" +CREATE TABLE IF NOT EXISTS "{{ schema_name }}"."{{ table_name }}" AS ( SELECT * FROM ( VALUES diff --git a/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja b/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja index 2ebde5bc..6e71c3dd 100644 --- a/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja +++ b/cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja @@ -33,7 +33,7 @@ CREATE TABLE {{ target_table }} AS ( {%- endif -%} {{ select_column_or_alias(join_cols_by_table) }} {{ neg_source_table }}.code - FROM "{{ pos_source_table }}_sampled_ids_{{table_suffix}}" AS sample_cohort, + FROM "{{ pos_source_table }}_sampled_ids_{{ table_suffix }}" AS sample_cohort, "{{ neg_source_table }}", {%- for key in join_cols_by_table %} "{{ key }}" diff --git a/cumulus_library/template_sql/templates.py b/cumulus_library/template_sql/templates.py index 330754e0..9c7f7c5b 100644 --- a/cumulus_library/template_sql/templates.py +++ b/cumulus_library/template_sql/templates.py @@ -218,7 +218,7 @@ def get_ctas_empty_query( :param schema_name: The athena schema to create the table in :param table_name: The name of the athena table to create :param table_cols: Comma deleniated column names, i.e. 
['first,second'] - :param table_cols: Allows specifying a data type per column (default: all varchar) + :param table_cols_types: Allows specifying a data type per column (default: all varchar) """ path = Path(__file__).parent if table_cols_types == []: diff --git a/pyproject.toml b/pyproject.toml index 28a2076c..a092795d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,8 +34,10 @@ dev = [ "pylint", ] test = [ + "freezegun", + "jmespath", "pytest", - "requests-mock" + "requests-mock", ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index a7a396b0..893f961e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,39 @@ """pytest mocks and testing utility classes/methods""" +import copy +import json +import os import tempfile from enum import IntEnum -from pathlib import Path +from pathlib import Path +import pandas import pytest +import numpy from cumulus_library.cli import StudyBuilder from cumulus_library.databases import create_db_backend, DatabaseCursor MOCK_DATA_DIR = f"{Path(__file__).parent}/test_data/duckdb_data" +ID_PATHS = { + "condition": [["id"], ["encounter", "reference"], ["subject", "reference"]], + "documentreference": [ + ["id"], + ["subject", "reference"], + ["context", "encounter", "reference"], + ], + "encounter": [["id"], ["subject", "reference"]], + "medicationrequest": [ + ["id"], + ["encounter", "reference"], + ["subject", "reference"], + ["reasonReference", "reference"], + ], + "observation": [["id"], ["encounter", "reference"]], + "patient": [["id"]], +} class ResourceTableIdPos(IntEnum): @@ -62,25 +84,107 @@ def duckdb_args(args: list, tmp_path): return args + ["--db-type", "duckdb", "--database", f"{tmp_path}/duck.db"] +def ndjson_data_generator(source_dir: Path, target_dir: Path, iterations: int): + """Uses the test data as a template to create large datasets + + Rather than a complex find/replace operation, we're just appending ints cast + as strings to any FHIR resource ID we find. 
If you're doing something that + relies on exact length of resource IDs, consider a different approach.""" + + def update_nested_obj(id_path, obj, i): + """Recursively update an object val, a thing that should be a pandas feature""" + if len(id_path) == 1: + # if we get a float it's a NumPy nan, so we should just let it go + if isinstance(obj, float): + pass + elif isinstance(obj, list): + obj[0][id_path[0]] = obj[0][id_path[0]] + str(i) + else: + obj[id_path[0]] = obj[id_path[0]] + str(i) + else: + if isinstance(obj, list): + obj[0][id_path[0]] = update_nested_obj( + id_path[1:], obj[0][id_path[0]], i + ) + else: + obj[id_path[0]] = update_nested_obj(id_path[1:], obj[id_path[0]], i) + + return obj + + for key in ID_PATHS: + for filepath in [f for f in Path(source_dir / key).iterdir()]: + ref_df = pandas.read_json(filepath, lines=True) + output_df = pandas.DataFrame() + for i in range(0, iterations): + df = ref_df.copy(deep=True) + for id_path in ID_PATHS[key]: + if len(id_path) == 1: + if id_path[0] in df.columns: + df[id_path[0]] = df[id_path[0]] + str(i) + else: + if id_path[0] in df.columns: + # panda's deep copy is not recursive, so we have to do it + # again for nested objects + df[id_path[0]] = df[id_path[0]].map( + lambda x: update_nested_obj( + id_path[1:], copy.deepcopy(x), i + ) + ) + output_df = pandas.concat([output_df, df]) + # workaround for pandas/null/boolean casting issues + for null_bool_col in ["multipleBirthBoolean"]: + if null_bool_col in output_df.columns: + output_df[null_bool_col] = output_df[null_bool_col].replace( + {0.0: False} + ) + output_df = output_df.replace({numpy.nan: None}) + + write_path = Path(str(target_dir) + f"/{key}/{filepath.name}") + write_path.parent.mkdir(parents=True, exist_ok=True) + # pandas.to_json() fails due to the datamodel complexity, so we'll manually + # coerce to ndjson + out_dict = output_df.to_dict(orient="records") + with open(write_path, "w", encoding="UTF-8") as f: + for row in out_dict: + f.write(json.dumps(row, default=str) + "\n") + + @pytest.fixture -def mock_db(): +def mock_db(tmp_path): """Provides a DuckDatabaseBackend for local testing""" - with tempfile.TemporaryDirectory() as tmpdir: - db = create_db_backend( - { - "db_type": "duckdb", - "schema_name": f"{tmpdir}/duck.db", - "load_ndjson_dir": MOCK_DATA_DIR, - } - ) - yield db + db = create_db_backend( + { + "db_type": "duckdb", + "schema_name": f"{tmp_path}/duck.db", + "load_ndjson_dir": MOCK_DATA_DIR, + } + ) + yield db @pytest.fixture -def mock_db_core(mock_db): # pylint: disable=redefined-outer-name +def mock_db_core(tmp_path, mock_db): # pylint: disable=redefined-outer-name """Provides a DuckDatabaseBackend with the core study ran for local testing""" - builder = StudyBuilder(mock_db) + builder = StudyBuilder(mock_db, data_path=f"{tmp_path}/data_path") builder.clean_and_build_study( - f"{Path(__file__).parent.parent}/cumulus_library/studies/core" + f"{Path(__file__).parent.parent}/cumulus_library/studies/core", True ) yield mock_db + + +@pytest.fixture +def mock_db_stats(tmp_path): + """Provides a DuckDatabaseBackend with a larger dataset for sampling stats""" + ndjson_data_generator(Path(MOCK_DATA_DIR), f"{tmp_path}/mock_data", 20) + db = create_db_backend( + { + "db_type": "duckdb", + "schema_name": f"{tmp_path}cumulus.duckdb", + "load_ndjson_dir": f"{tmp_path}/mock_data", + } + ) + builder = StudyBuilder(db, data_path=f"{tmp_path}/data_path") + builder.clean_and_build_study( + f"{Path(__file__).parent.parent}/cumulus_library/studies/core", True + ) + yield db 
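A minimal sketch of the ID-suffixing scheme `ndjson_data_generator` applies (standalone illustration; the sample record below is hypothetical, not taken from the test data):

```python
# Iteration i appends str(i) to each resource ID and to each reference that
# points at one, so every generated copy stays internally consistent while
# its IDs never collide with another iteration's copies.
record = {"id": "abc", "subject": {"reference": "Patient/123"}}
for i in range(3):
    print(
        {
            "id": record["id"] + str(i),
            "subject": {"reference": record["subject"]["reference"] + str(i)},
        }
    )
# {'id': 'abc0', 'subject': {'reference': 'Patient/1230'}}
# {'id': 'abc1', 'subject': {'reference': 'Patient/1231'}}
# {'id': 'abc2', 'subject': {'reference': 'Patient/1232'}}
```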
diff --git a/tests/test_cli.py b/tests/test_cli.py index 4b84f7e4..032e183b 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -114,7 +114,7 @@ def test_cli_path_mapping( args = duckdb_args(args, tmp_path) cli.main(cli_args=args) db = DuckDatabaseBackend(f"{tmp_path}/duck.db") - assert expected in db.cursor().execute("show tables").fetchone()[0] + assert (expected,) in db.cursor().execute("show tables").fetchall() @mock.patch.dict( @@ -140,6 +140,7 @@ def test_count_builder_mapping(mock_path, tmp_path): cli.main(cli_args=args) db = DuckDatabaseBackend(f"{tmp_path}/duck.db") assert [ + ("study_python_counts_valid__lib_transactions",), ("study_python_counts_valid__table1",), ("study_python_counts_valid__table2",), ] == db.cursor().execute("show tables").fetchall() @@ -194,7 +195,7 @@ def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-a @pytest.mark.parametrize( "build_args,export_args,expected_tables", [ - (["build", "-t", "core"], ["export", "-t", "core"], 37), + (["build", "-t", "core"], ["export", "-t", "core"], 38), ( [ # checking that a study is loaded from a child directory of a user-defined path "build", @@ -204,9 +205,9 @@ def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-a "tests/test_data/", ], ["export", "-t", "study_valid", "-s", "tests/test_data/"], - 1, + 2, ), - (["build", "-t", "vocab"], None, 2), + (["build", "-t", "vocab"], None, 3), ( [ # checking that a study is loaded from the directory of a user-defined path "build", @@ -216,7 +217,7 @@ def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-a "tests/test_data/study_valid/", ], ["export", "-t", "study_valid", "-s", "tests/test_data/study_valid/"], - 1, + 2, ), ], ) diff --git a/tests/test_conftest.py b/tests/test_conftest.py new file mode 100644 index 00000000..7eeb68c7 --- /dev/null +++ b/tests/test_conftest.py @@ -0,0 +1,45 @@ +import json + +from pathlib import Path + +from tests.conftest import ndjson_data_generator, MOCK_DATA_DIR, ID_PATHS + + +def test_ndjson_data_generator(tmp_path): + iters = 20 + target = tmp_path + # generating outside tmp storage for debugging + # target = Path(MOCK_DATA_DIR + '/test_output') + ndjson_data_generator(Path(MOCK_DATA_DIR), target, iters) + for key in ID_PATHS: + for filepath in [f for f in Path(target / key).iterdir()]: + with open(filepath) as f: + first_new = json.loads(next(f)) + for line in f: + pass + last_new = json.loads(line) + with open(f"{Path(MOCK_DATA_DIR)}/{key}/{filepath.name}") as f: + first_line = next(f) + first_ref = json.loads(first_line) + # handling patient file of length 1: + line = first_line + for line in f: + pass + last_ref = json.loads(line) + for source in [[first_new, first_ref, 0], [last_new, last_ref, iters - 1]]: + for id_path in ID_PATHS[key]: + new_test = source[0] + ref_test = source[1] + for subkey in id_path: + if new_test is None: + break + if isinstance(new_test, list): + new_test = new_test[0].get(subkey) + ref_test = ref_test[0].get(subkey) + else: + new_test = new_test.get(subkey) + ref_test = ref_test.get(subkey) + if ref_test is not None: + assert new_test == ref_test + str(source[2]) + else: + assert ref_test == new_test diff --git a/tests/test_core.py b/tests/test_core.py index 92e1ce17..6a1e37a2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -61,7 +61,7 @@ def test_core_tables(mock_db_core, table): assert len(table_rows) == len(ref_table) -def test_core_count_missing_data(mock_db): +def test_core_count_missing_data(tmp_path, mock_db): 
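+ # StudyBuilder now requires a data_path; tmp_path gives it somewhere safe to write stats artifacts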
null_code_class = { "id": None, "code": None, @@ -73,9 +73,11 @@ def test_core_count_missing_data(mock_db): cursor = mock_db.cursor() modify_resource_column(cursor, "encounter", "class", null_code_class) - builder = StudyBuilder(mock_db) + builder = StudyBuilder(mock_db, f"{tmp_path}/data_path/") builder.clean_and_build_study( - f"{Path(__file__).parent.parent}/cumulus_library/studies/core" + f"{Path(__file__).parent.parent}/cumulus_library/studies/core", + f"{tmp_path}/data_path/", + False, ) table_rows = cursor.execute("SELECT * FROM core__count_encounter_month").fetchall() # For regenerating data if needed diff --git a/tests/test_data/psm/psm_config.toml b/tests/test_data/psm/psm_config.toml index 19f47b37..47c81f44 100644 --- a/tests/test_data/psm/psm_config.toml +++ b/tests/test_data/psm/psm_config.toml @@ -7,8 +7,8 @@ primary_ref = 'encounter_ref' count_ref = 'subject_ref' count_table = 'core__condition' dependent_variable = "example_diagnosis" -pos_sample_size = 5 -neg_sample_size = 5 +pos_sample_size = 50 +neg_sample_size = 500 seed = 1234567890 [join_cols_by_table.core__encounter] join_id = "encounter_ref" diff --git a/tests/test_data/psm/psm_config_no_optional.toml b/tests/test_data/psm/psm_config_no_optional.toml index cd079658..36a7b795 100644 --- a/tests/test_data/psm/psm_config_no_optional.toml +++ b/tests/test_data/psm/psm_config_no_optional.toml @@ -1,9 +1,10 @@ +config_type = "psm" classification_json = "dsm5_classifications.json" pos_source_table = "psm_test__psm_cohort" neg_source_table = "core__condition" -target_table = "psm_test__psm_encounter_covariate_no_optional" +target_table = "psm_test__psm_encounter_covariate" primary_ref = 'encounter_ref' dependent_variable = "example_diagnosis" -pos_sample_size = 5 -neg_sample_size = 5 +pos_sample_size = 50 +neg_sample_size = 500 seed = 1234567890 diff --git a/tests/test_psm.py b/tests/test_psm.py index 236cb473..c8904fb2 100644 --- a/tests/test_psm.py +++ b/tests/test_psm.py @@ -1,30 +1,44 @@ """ tests for propensity score matching generation """ +from datetime import datetime from pathlib import Path import pytest +from freezegun import freeze_time + from cumulus_library.cli import StudyBuilder from cumulus_library.statistics.psm import PsmBuilder +""" + ( + "psm_config_no_optional.toml", + 52, + 266, + {'encounter_ref': 'Encounter/03e34b19-2889-b828-792d-2a83400c55be10', 'example_diagnosis': '1', 'code': '33737001'}, + {'encounter_ref': 'Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f19', 'example_diagnosis': '0', 'code': '195662009'}, + ), +""" + +@freeze_time("2024-01-01") @pytest.mark.parametrize( "toml_def,pos_set,neg_set,expected_first_record,expected_last_record", [ ( "psm_config.toml", - 7, - 7, + 52, + 266, { - "encounter_ref": "Encounter/11381dc6-0e06-da55-0735-d1e7bbf8bb35", + "encounter_ref": "Encounter/03e34b19-2889-b828-792d-2a83400c55be10", "example_diagnosis": "1", "instance_count": 1, - "gender": "male", + "gender": "female", "race": "white", - "code": "44465007", + "code": "33737001", }, { - "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f1", + "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f19", "example_diagnosis": "0", "instance_count": 1, "gender": "female", @@ -32,51 +46,52 @@ "code": "195662009", }, ), - ( - "psm_config_no_optional.toml", - 7, - 7, - { - "encounter_ref": "Encounter/11381dc6-0e06-da55-0735-d1e7bbf8bb35", - "example_diagnosis": "1", - "code": "44465007", - }, - { - "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f1", - "example_diagnosis": "0", 
- "code": "195662009", - }, - ), ], ) def test_psm_create( - mock_db_core, + tmp_path, + mock_db_stats, toml_def, pos_set, neg_set, expected_first_record, expected_last_record, ): - builder = StudyBuilder(mock_db_core) - psm = PsmBuilder(f"{Path(__file__).parent}/test_data/psm/{toml_def}") - mock_db_core.cursor().execute( - "create table core__psm_cohort as (select * from core__condition " - f"ORDER BY {psm.config.primary_ref} limit 10)" + builder = StudyBuilder(mock_db_stats, data_path=Path(tmp_path)) + psm = PsmBuilder( + f"{Path(__file__).parent}/test_data/psm/{toml_def}", Path(tmp_path) + ) + mock_db_stats.cursor().execute( + "create table psm_test__psm_cohort as (select * from core__condition " + f"ORDER BY {psm.config.primary_ref} limit 100)" ).df() - mock_db_core.cursor().execute("select * from core__psm_cohort").fetchall() - + mock_db_stats.cursor().execute("select * from psm_test__psm_cohort").fetchall() + safe_timestamp = ( + datetime.now() + .replace(microsecond=0) + .isoformat() + .replace(":", "_") + .replace("-", "_") + ) psm.execute_queries( - mock_db_core.pandas_cursor(), builder.schema_name, False, drop_table=True + mock_db_stats.pandas_cursor(), + builder.schema_name, + False, + drop_table=True, + table_suffix=safe_timestamp, ) df = ( - mock_db_core.cursor() + mock_db_stats.cursor() .execute("select * from psm_test__psm_encounter_covariate") .df() ) + print(df.columns) ed_series = df["example_diagnosis"].value_counts() assert ed_series.iloc[0] == neg_set assert ed_series.iloc[1] == pos_set first_record = df.iloc[0].to_dict() + print(first_record) assert first_record == expected_first_record last_record = df.iloc[neg_set + pos_set - 1].to_dict() + print(last_record) assert last_record == expected_last_record diff --git a/tests/test_psm_templates.py b/tests/test_psm_templates.py index 57f81aed..5f1d2550 100644 --- a/tests/test_psm_templates.py +++ b/tests/test_psm_templates.py @@ -53,12 +53,13 @@ def test_get_distinct_ids( @pytest.mark.parametrize( - "target,pos_source,neg_source,primary_ref,dep_var,join_cols_by_table,count_ref,count_table,expected,raises", + "target,pos_source,neg_source,table_suffix,primary_ref,dep_var,join_cols_by_table,count_ref,count_table,expected,raises", [ ( "target", "pos_table", "neg_table", + "2024_01_01_11_11_11", "subject_id", "has_flu", {}, @@ -69,7 +70,7 @@ def test_get_distinct_ids( sample_cohort."subject_id", sample_cohort."has_flu", neg_table.code - FROM "pos_table_sampled_ids" AS sample_cohort, + FROM "pos_table_sampled_ids_2024_01_01_11_11_11" AS sample_cohort, "neg_table", WHERE sample_cohort."subject_id" = "neg_table"."subject_id" @@ -81,6 +82,7 @@ def test_get_distinct_ids( "target", "pos_table", "neg_table", + "2024_01_01_11_11_11", "subject_id", "has_flu", { @@ -103,7 +105,7 @@ def test_get_distinct_ids( "join_table"."a", "join_table"."b" AS "c", neg_table.code - FROM "pos_table_sampled_ids" AS sample_cohort, + FROM "pos_table_sampled_ids_2024_01_01_11_11_11" AS sample_cohort, "neg_table", "join_table" WHERE @@ -118,6 +120,7 @@ def test_get_distinct_ids( "target", "pos_table", "neg_table", + "2024_01_01_11_11_11", "subject_id", "has_flu", {}, @@ -132,6 +135,7 @@ def test_create_covariate_table( target, pos_source, neg_source, + table_suffix, primary_ref, dep_var, join_cols_by_table, @@ -145,6 +149,7 @@ def test_create_covariate_table( target, pos_source, neg_source, + table_suffix, primary_ref, dep_var, join_cols_by_table, diff --git a/tests/test_study_parser.py b/tests/test_study_parser.py index dc7c3643..6ccf3b25 100644 --- 
a/tests/test_study_parser.py +++ b/tests/test_study_parser.py @@ -1,11 +1,14 @@ """ tests for study parser against mocks in test_data """ import builtins import pathlib + from contextlib import nullcontext as does_not_raise +from pathlib import Path from unittest import mock import pytest +from cumulus_library.enums import PROTECTED_TABLE_KEYWORDS from cumulus_library.study_parser import StudyManifestParser, StudyManifestParsingError from tests.test_data.parser_mock_data import get_mock_toml, mock_manifests @@ -75,81 +78,111 @@ def test_manifest_data(manifest_key, raises): @pytest.mark.parametrize( - "schema,verbose,prefix,confirm,query_res,raises", + "schema,verbose,prefix,confirm,target,raises", [ - ("schema", True, None, None, "study_valid__table", does_not_raise()), - ("schema", False, None, None, "study_valid__table", does_not_raise()), - ("schema", None, None, None, "study_valid__table", does_not_raise()), - (None, True, None, None, [], pytest.raises(ValueError)), - ("schema", None, None, None, "study_valid__etl_table", does_not_raise()), - ("schema", None, None, None, "study_valid__nlp_table", does_not_raise()), - ("schema", None, None, None, "study_valid__lib_table", does_not_raise()), - ("schema", None, None, None, "study_valid__lib", does_not_raise()), - ("schema", None, "foo", "y", "foo_table", does_not_raise()), - ("schema", None, "foo", "n", "foo_table", pytest.raises(SystemExit)), + ("main", True, None, None, "study_valid__table", does_not_raise()), + ("main", False, None, None, "study_valid__table", does_not_raise()), + ("main", None, None, None, "study_valid__table", does_not_raise()), + (None, True, None, None, None, pytest.raises(ValueError)), + ("main", None, None, None, "study_valid__etl_table", does_not_raise()), + ("main", None, None, None, "study_valid__nlp_table", does_not_raise()), + ("main", None, None, None, "study_valid__lib_table", does_not_raise()), + ("main", None, None, None, "study_valid__lib", does_not_raise()), + ("main", None, "foo", "y", "foo_table", does_not_raise()), + ("main", None, "foo", "n", "foo_table", pytest.raises(SystemExit)), ], ) -@mock.patch("cumulus_library.helper.query_console_output") -def test_clean_study(mock_output, schema, verbose, prefix, confirm, query_res, raises): +def test_clean_study(mock_db, schema, verbose, prefix, confirm, target, raises): with raises: + protected_strs = [x.value for x in PROTECTED_TABLE_KEYWORDS] with mock.patch.object(builtins, "input", lambda _: confirm): - mock_cursor = mock.MagicMock() - mock_cursor.fetchall.return_value = [[query_res]] parser = StudyManifestParser("./tests/test_data/study_valid/") - tables = parser.clean_study(mock_cursor, schema, verbose, prefix=prefix) - - if "study_valid__table" not in query_res and prefix is None: - assert not tables + if target is not None: + mock_db.cursor().execute(f"CREATE TABLE {target} (test int);") + parser.clean_study(mock_db.cursor(), schema, verbose=verbose, prefix=prefix) + remaining_tables = ( + mock_db.cursor() + .execute(f"select distinct(table_name) from information_schema.tables") + .fetchall() + ) + if any(x in target for x in protected_strs): + assert (target,) in remaining_tables else: - assert tables == [[query_res, "VIEW"]] - if prefix is not None: - assert prefix in mock_cursor.execute.call_args.args[0] - else: - assert "study_valid__" in mock_cursor.execute.call_args.args[0] - assert mock_output.is_called() + assert (target,) not in remaining_tables + + +""" + ("./tests/test_data/study_python_valid/", True, ('study_python_valid__table',), 
does_not_raise()), + + ( + "./tests/test_data/study_python_no_subclass/", + True, + (), + does_not_raise(), + ), + +""" @pytest.mark.parametrize( - "path,verbose,raises", + "study_path,verbose,expects,raises", [ - ("./tests/test_data/study_valid/", True, does_not_raise()), - ("./tests/test_data/study_valid/", False, does_not_raise()), - ("./tests/test_data/study_valid/", None, does_not_raise()), - ("./tests/test_data/study_wrong_prefix/", None, pytest.raises(SystemExit)), - ("./tests/test_data/study_python_valid/", True, does_not_raise()), ( - "./tests/test_data/study_python_no_subclass/", + "./tests/test_data/study_valid/", True, - pytest.raises(StudyManifestParsingError), + ("study_valid__table",), + does_not_raise(), + ), + ( + "./tests/test_data/study_valid/", + False, + ("study_valid__table",), + does_not_raise(), + ), + ( + "./tests/test_data/study_valid/", + None, + ("study_valid__table",), + does_not_raise(), + ), + ("./tests/test_data/study_wrong_prefix/", None, [], pytest.raises(SystemExit)), + ( + "./tests/test_data/study_invalid_no_dunder/", + True, + (), + pytest.raises(SystemExit), ), - ("./tests/test_data/study_invalid_no_dunder/", True, pytest.raises(SystemExit)), ( "./tests/test_data/study_invalid_two_dunder/", True, + (), pytest.raises(SystemExit), ), ( "./tests/test_data/study_invalid_reserved_word/", True, + (), pytest.raises(SystemExit), ), ], ) -@mock.patch("cumulus_library.helper.query_console_output") -def test_build_study(mock_output, path, verbose, raises): +def test_build_study(mock_db, study_path, verbose, expects, raises): with raises: - mock_cursor = mock.MagicMock() - parser = StudyManifestParser(path) - parser.run_table_builder(mock_cursor, verbose) - queries = parser.build_study(mock_cursor, verbose) - if "python" not in path: - assert "CREATE TABLE" in queries[0][0] - assert mock_output.is_called() - - -def test_export_study(monkeypatch): - mock_cursor = mock.MagicMock() - parser = StudyManifestParser("./tests/test_data/study_valid/") - monkeypatch.setattr(pathlib, "PosixPath", mock.MagicMock()) - queries = parser.export_study(mock_cursor, "./path") - assert queries == ["select * from study_valid__table"] + parser = StudyManifestParser(study_path) + parser.build_study(mock_db.cursor(), verbose) + tables = ( + mock_db.cursor() + .execute("SELECT distinct(table_name) FROM information_schema.tables ") + .fetchall() + ) + assert expects in tables + + +def test_export_study(tmp_path, mock_db_core): + parser = StudyManifestParser( + f"{Path(__file__).parent.parent}/cumulus_library/studies/core", + data_path=f"{tmp_path}/export", + ) + parser.export_study(mock_db_core, f"{tmp_path}/export") + for file in Path(f"{tmp_path}/export").glob("*.*"): + assert file in parser.get_export_table_list() From 770729c98db771c67fe6110961da01d56f8b0386 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Wed, 20 Dec 2023 16:18:28 -0500 Subject: [PATCH 03/13] Stats test coverage --- cumulus_library/cli.py | 2 - cumulus_library/cli_parser.py | 2 +- cumulus_library/protected_table_builder.py | 1 - cumulus_library/study_parser.py | 2 +- tests/conftest.py | 10 +- tests/test_cli.py | 74 ++++++++- tests/test_data/psm/psm_cohort.sql | 2 +- tests/test_psm.py | 25 +-- tests/test_study_parser.py | 178 ++++++++++++++++++--- 9 files changed, 251 insertions(+), 45 deletions(-) diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index ffafb762..541e0fe5 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -135,7 +135,6 @@ def clean_and_build_study( self.schema_name, 
verbose=self.verbose, stats_build=stats_build, - data_path=self.data_path, ) self.update_transactions(studyparser.get_study_prefix(), "finished") except SystemExit as exit: @@ -378,7 +377,6 @@ def main(cli_args=None): if args.get("data_path"): args["data_path"] = get_abs_posix_path(args["data_path"]) - return run_cli(args) diff --git a/cumulus_library/cli_parser.py b/cumulus_library/cli_parser.py index b5f15bd3..409f2545 100644 --- a/cumulus_library/cli_parser.py +++ b/cumulus_library/cli_parser.py @@ -169,8 +169,8 @@ def get_parser() -> argparse.ArgumentParser: add_table_builder_argument(build) add_study_dir_argument(build) add_verbose_argument(build) - add_data_path_argument(build) add_db_config(build) + add_data_path_argument(build) build.add_argument( "--statistics", action="store_true", diff --git a/cumulus_library/protected_table_builder.py b/cumulus_library/protected_table_builder.py index b8deb85b..e208fdc5 100644 --- a/cumulus_library/protected_table_builder.py +++ b/cumulus_library/protected_table_builder.py @@ -25,7 +25,6 @@ class ProtectedTableBuilder(BaseTableBuilder): def prepare_queries( self, cursor: object, schema: str, study_name: str, study_stats: dict ): - print("hi") self.queries.append( get_ctas_empty_query( schema, diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index 7e7bfcd6..5c07e7cf 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -207,6 +207,7 @@ def clean_study( FROM {drop_prefix}{PROTECTED_TABLES.STATISTICS.value} WHERE study_name = '{display_prefix}'""" ).fetchall() + print(protected_list) for protected_tuple in protected_list: if protected_tuple in tuple_list: tuple_list.remove(protected_tuple) @@ -395,7 +396,6 @@ def run_statistics_builders( schema: str, verbose: bool = False, stats_build: bool = False, - data_path: PosixPath = None, ) -> None: """Loads statistics modules from toml definitions and executes diff --git a/tests/conftest.py b/tests/conftest.py index 893f961e..60b1511c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -62,14 +62,20 @@ def modify_resource_column( cursor.execute(f"CREATE VIEW {table} AS SELECT * from {table}_{col}_df") -def duckdb_args(args: list, tmp_path): +def duckdb_args(args: list, tmp_path, stats=False): """Convenience function for adding duckdb args to a CLI mock""" + if stats: + ndjson_data_generator(Path(MOCK_DATA_DIR), Path(f"{tmp_path}/stats_db"), 20) + target = f"{tmp_path}/stats_db" + else: + target = f"{MOCK_DATA_DIR}" + if args[0] == "build": return args + [ "--db-type", "duckdb", "--load-ndjson-dir", - f"{MOCK_DATA_DIR}", + target, "--database", f"{tmp_path}/duck.db", ] diff --git a/tests/test_cli.py b/tests/test_cli.py index 032e183b..0bbcd4d8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,4 +1,5 @@ """ tests for the cli interface to studies """ +import builtins import glob import os import sysconfig @@ -110,7 +111,6 @@ def test_cli_path_mapping( "study_python_valid": "study_python_valid", }, } - sysconfig.get_path("purelib") args = duckdb_args(args, tmp_path) cli.main(cli_args=args) db = DuckDatabaseBackend(f"{tmp_path}/duck.db") @@ -171,6 +171,15 @@ def test_count_builder_mapping(mock_path, tmp_path): ], "foo", ), + ( + [ + "clean", + "-t", + "core", + "--statistics", + ], + "core__", + ), ], ) def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-argument @@ -179,12 +188,14 @@ def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-a cli_args=duckdb_args(["build", "-t", "core", 
"--database", "test"], tmp_path) ) with does_not_raise(): - cli.main( - cli_args=args + ["--db-type", "duckdb", "--database", f"{tmp_path}/duck.db"] - ) - db = DuckDatabaseBackend(f"{tmp_path}/duck.db") - for table in db.cursor().execute("show tables").fetchall(): - assert expected not in table + with mock.patch.object(builtins, "input", lambda _: "y"): + cli.main( + cli_args=args + + ["--db-type", "duckdb", "--database", f"{tmp_path}/duck.db"] + ) + db = DuckDatabaseBackend(f"{tmp_path}/duck.db") + for table in db.cursor().execute("show tables").fetchall(): + assert expected not in table @mock.patch("builtins.open", MockVocabBsv().open) @@ -209,12 +220,14 @@ def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-a ), (["build", "-t", "vocab"], None, 3), ( - [ # checking that a study is loaded from the directory of a user-defined path + [ # checking that a study is loaded from the directory of a user-defined path. + # we're also validating that the CLI accpes the statistics keyword, though "build", "-t", "study_valid", "-s", "tests/test_data/study_valid/", + "--statistics", ], ["export", "-t", "study_valid", "-s", "tests/test_data/study_valid/"], 2, @@ -253,6 +266,51 @@ def test_cli_executes_queries(tmp_path, build_args, export_args, expected_tables assert any(export_table in x for x in csv_files) +@mock.patch.dict( + os.environ, + clear=True, +) +def test_cli_stats_rebuild(tmp_path): + """Validates statistics build behavior + + Since this is a little obtuse - we are checking: + - that stats builds run at all + - that a results table is created on the first run + - that a results table is :not: created on the second run + - that a results table is created when we explicitly ask for one with a CLI flag + """ + + cli.main( + cli_args=duckdb_args( + ["build", "-t", "core", "--database", "test"], tmp_path, stats=True + ) + ) + arg_list = [ + "build", + "-s", + "./tests/test_data", + "-t", + "psm", + "--db-type", + "duckdb", + "--database", + f"{tmp_path}/duck.db", + ] + cli.main(cli_args=arg_list + [f"{tmp_path}/export"]) + cli.main(cli_args=arg_list + [f"{tmp_path}/export"]) + cli.main(cli_args=arg_list + [f"{tmp_path}/export", "--statistics"]) + db = DuckDatabaseBackend(f"{tmp_path}/duck.db") + expected = ( + db.cursor() + .execute( + "SELECT table_name FROM information_schema.tables " + "WHERE table_name LIKE 'psm_test__psm_encounter_covariate_%'" + ) + .fetchall() + ) + assert len(expected) == 2 + + @mock.patch.dict( os.environ, clear=True, diff --git a/tests/test_data/psm/psm_cohort.sql b/tests/test_data/psm/psm_cohort.sql index e2908f7f..bfcacb78 100644 --- a/tests/test_data/psm/psm_cohort.sql +++ b/tests/test_data/psm/psm_cohort.sql @@ -1,3 +1,3 @@ CREATE TABLE psm_test__psm_cohort AS ( - SELECT * FROM core__condition ORDER BY condition_id DESC LIMIT 7 --noqa: AM04 + SELECT * FROM core__condition ORDER BY condition_id DESC LIMIT 100 --noqa: AM04 ); diff --git a/tests/test_psm.py b/tests/test_psm.py index c8904fb2..897e5565 100644 --- a/tests/test_psm.py +++ b/tests/test_psm.py @@ -10,16 +10,6 @@ from cumulus_library.cli import StudyBuilder from cumulus_library.statistics.psm import PsmBuilder -""" - ( - "psm_config_no_optional.toml", - 52, - 266, - {'encounter_ref': 'Encounter/03e34b19-2889-b828-792d-2a83400c55be10', 'example_diagnosis': '1', 'code': '33737001'}, - {'encounter_ref': 'Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f19', 'example_diagnosis': '0', 'code': '195662009'}, - ), -""" - @freeze_time("2024-01-01") @pytest.mark.parametrize( @@ -46,6 +36,21 @@ 
"code": "195662009", }, ), + ( + "psm_config_no_optional.toml", + 52, + 266, + { + "encounter_ref": "Encounter/03e34b19-2889-b828-792d-2a83400c55be10", + "example_diagnosis": "1", + "code": "33737001", + }, + { + "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f19", + "example_diagnosis": "0", + "code": "195662009", + }, + ), ], ) def test_psm_create( diff --git a/tests/test_study_parser.py b/tests/test_study_parser.py index 6ccf3b25..5b6d6f3b 100644 --- a/tests/test_study_parser.py +++ b/tests/test_study_parser.py @@ -8,7 +8,7 @@ import pytest -from cumulus_library.enums import PROTECTED_TABLE_KEYWORDS +from cumulus_library.enums import PROTECTED_TABLE_KEYWORDS, PROTECTED_TABLES from cumulus_library.study_parser import StudyManifestParser, StudyManifestParsingError from tests.test_data.parser_mock_data import get_mock_toml, mock_manifests @@ -78,28 +78,56 @@ def test_manifest_data(manifest_key, raises): @pytest.mark.parametrize( - "schema,verbose,prefix,confirm,target,raises", + "schema,verbose,prefix,confirm,stats,target,raises", [ - ("main", True, None, None, "study_valid__table", does_not_raise()), - ("main", False, None, None, "study_valid__table", does_not_raise()), - ("main", None, None, None, "study_valid__table", does_not_raise()), - (None, True, None, None, None, pytest.raises(ValueError)), - ("main", None, None, None, "study_valid__etl_table", does_not_raise()), - ("main", None, None, None, "study_valid__nlp_table", does_not_raise()), - ("main", None, None, None, "study_valid__lib_table", does_not_raise()), - ("main", None, None, None, "study_valid__lib", does_not_raise()), - ("main", None, "foo", "y", "foo_table", does_not_raise()), - ("main", None, "foo", "n", "foo_table", pytest.raises(SystemExit)), + ("main", True, None, None, False, "study_valid__table", does_not_raise()), + ("main", False, None, None, False, "study_valid__table", does_not_raise()), + ("main", None, None, None, False, "study_valid__table", does_not_raise()), + (None, True, None, None, False, None, pytest.raises(SystemExit)), + ("main", None, None, None, False, "study_valid__etl_table", does_not_raise()), + ("main", None, None, None, False, "study_valid__nlp_table", does_not_raise()), + ("main", None, None, None, False, "study_valid__lib_table", does_not_raise()), + ("main", None, None, None, False, "study_valid__lib", does_not_raise()), + ("main", None, "foo", "y", False, "foo_table", does_not_raise()), + ("main", None, "foo", "n", False, "foo_table", pytest.raises(SystemExit)), + ("main", True, None, "y", True, "study_valid__table", does_not_raise()), + ( + "main", + True, + None, + "n", + True, + "study_valid__table", + pytest.raises(SystemExit), + ), ], ) -def test_clean_study(mock_db, schema, verbose, prefix, confirm, target, raises): +def test_clean_study(mock_db, schema, verbose, prefix, confirm, stats, target, raises): with raises: protected_strs = [x.value for x in PROTECTED_TABLE_KEYWORDS] with mock.patch.object(builtins, "input", lambda _: confirm): parser = StudyManifestParser("./tests/test_data/study_valid/") + parser.run_protected_table_builder(mock_db.cursor(), schema) + + # We're mocking stats tables since creating them programmatically + # is very slow and we're trying a lot of conditions + mock_db.cursor().execute( + f"CREATE TABLE {parser.get_study_prefix()}__" + f"{PROTECTED_TABLES.STATISTICS.value} " + "AS SELECT 'study_valid' as study_name, " + "'study_valid__123' AS table_name" + ) + mock_db.cursor().execute("CREATE TABLE study_valid__123 (test int)") + if target is not 
None: mock_db.cursor().execute(f"CREATE TABLE {target} (test int);") - parser.clean_study(mock_db.cursor(), schema, verbose=verbose, prefix=prefix) + parser.clean_study( + mock_db.cursor(), + schema, + verbose=verbose, + prefix=prefix, + stats_clean=stats, + ) remaining_tables = ( mock_db.cursor() .execute(f"select distinct(table_name) from information_schema.tables") @@ -109,19 +137,92 @@ def test_clean_study(mock_db, schema, verbose, prefix, confirm, target, raises): assert (target,) in remaining_tables else: assert (target,) not in remaining_tables + assert ( + f"{parser.get_study_prefix()}__{PROTECTED_TABLES.TRANSACTIONS.value}", + ) in remaining_tables + if stats: + assert ( + f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + ) not in remaining_tables + assert ("study_valid__123",) not in remaining_tables + else: + assert ( + f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + ) in remaining_tables + assert ("study_valid__123",) in remaining_tables -""" - ("./tests/test_data/study_python_valid/", True, ('study_python_valid__table',), does_not_raise()), +@pytest.mark.parametrize( + "study_path,stats", + [ + ("./tests/test_data/study_valid/", False), + ("./tests/test_data/psm/", True), + ], +) +def test_run_protected_table_builder(mock_db, study_path, stats): + parser = StudyManifestParser(study_path) + parser.run_protected_table_builder(mock_db.cursor(), "main") + tables = ( + mock_db.cursor() + .execute("SELECT distinct(table_name) FROM information_schema.tables ") + .fetchall() + ) + assert ( + f"{parser.get_study_prefix()}__{PROTECTED_TABLES.TRANSACTIONS.value}", + ) in tables + if stats: + assert ( + f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + ) in tables + else: + assert ( + f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + ) not in tables + +@pytest.mark.parametrize( + "study_path,verbose,expects,raises", + [ + ( + "./tests/test_data/study_python_valid/", + True, + ("study_python_valid__table",), + does_not_raise(), + ), + ( + "./tests/test_data/study_python_valid/", + False, + ("study_python_valid__table",), + does_not_raise(), + ), + ( + "./tests/test_data/study_python_valid/", + None, + ("study_python_valid__table",), + does_not_raise(), + ), ( "./tests/test_data/study_python_no_subclass/", True, (), - does_not_raise(), + pytest.raises(StudyManifestParsingError), ), - -""" + ], +) +def test_table_builder(mock_db, study_path, verbose, expects, raises): + with raises: + parser = StudyManifestParser(study_path) + parser.run_table_builder( + mock_db.cursor(), + "main", + verbose, + ) + tables = ( + mock_db.cursor() + .execute("SELECT distinct(table_name) FROM information_schema.tables ") + .fetchall() + ) + assert expects in tables @pytest.mark.parametrize( @@ -178,6 +279,45 @@ def test_build_study(mock_db, study_path, verbose, expects, raises): assert expects in tables +@pytest.mark.parametrize( + "study_path,stats,expects,raises", + [ + ( + "./tests/test_data/psm/", + False, + (f"psm_test__psm_encounter_covariate",), + does_not_raise(), + ), + ( + "./tests/test_data/psm/", + True, + (f"psm_test__psm_encounter_covariate",), + does_not_raise(), + ), + ], +) +def test_run_statistics_builders( + tmp_path, mock_db_stats, study_path, stats, expects, raises +): + with raises: + parser = StudyManifestParser(study_path, data_path=tmp_path) + parser.run_protected_table_builder(mock_db_stats.cursor(), "main") + parser.build_study(mock_db_stats.cursor(), "main") + parser.run_statistics_builders( + 
mock_db_stats.cursor(), "main", stats_build=stats + ) + tables = ( + mock_db_stats.cursor() + .execute("SELECT distinct(table_name) FROM information_schema.tables") + .fetchall() + ) + print(tables) + if stats: + assert expects in tables + else: + assert expects not in tables + + def test_export_study(tmp_path, mock_db_core): parser = StudyManifestParser( f"{Path(__file__).parent.parent}/cumulus_library/studies/core", From efdee2f390a60f7f8c7db0c90393815f65b1ecb9 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Wed, 20 Dec 2023 16:41:03 -0500 Subject: [PATCH 04/13] Docs update, linting --- .../studies/template/manifest.toml | 25 +++++++++---- docs/statistics.md | 35 ++++++++++++++++++- docs/statistics/propensity-score-matching.md | 11 ++---- tests/test_templates.py | 6 ++-- 4 files changed, 58 insertions(+), 19 deletions(-) diff --git a/cumulus_library/studies/template/manifest.toml b/cumulus_library/studies/template/manifest.toml index 49b9f9c7..bfe197c2 100644 --- a/cumulus_library/studies/template/manifest.toml +++ b/cumulus_library/studies/template/manifest.toml @@ -4,6 +4,15 @@ # be the same name as the folder the study definition is in. study_prefix = "template" +# For most use cases, this should not be required, but if you need to programmatically +# build tables, you can provide a list of files implementing BaseTableBuilder. +# See vocab and core studies for examples of this pattern. These run before +# any SQL execution +# [table_builder_config] +# file_names = [ +# "my_table_builder.py", +# ] + # The following section describes all tables that should be generated directly # from SQL files. [sql_config] @@ -38,10 +47,14 @@ export_list = [ # "count.py" # ] -# For most use cases, this should not be required, but if you need to programmatically -# build tables, you can provide a list of files implementing BaseTableBuilder. -# See vocab and core studies for examples of this pattern -# [table_builder_config] -# file_names = [ -# "my_table_builder.py", +# For more specialized statistics, we provide a toml-based config entrypoint. The +# details of these configs will vary, depending on which statistical method you're +# invoking. For more details, see the statistics section of the docs for a list of +# supported approaches. +# These will run last, so all the data in your study will exist by the time these +# are invoked. +# [statistics_config] +# file_names = +# [ +# "psm_config.toml" # ] diff --git a/docs/statistics.md b/docs/statistics.md index 1674ee26..a9a51d56 100644 --- a/docs/statistics.md +++ b/docs/statistics.md @@ -12,4 +12,37 @@ has_children: true This page contains detailed documentation on statistics utilities provided for use in Cumulus studies. -- [Propensity Score Matching](statistics/propensity-score-matching.md). \ No newline at end of file +## Specific stats modules + +- [Propensity Score Matching](statistics/propensity-score-matching.md). + +## General usage guidelines + +You can invoke a statistic task from your study's manifest the same way that you +would run SQL or python files - the only difference is that you point it at another +toml file, which allows stats configs to have different input parameters depending +on the analysis you're trying to perform. + +In your manifest, you'd add a section like this: +```toml +[statistics_config] +file_names = [ + "psm_config.toml" +] +``` + +We'll use this as a way to load statistics configurations. Since some of these +statistical methods may be quasi-experimental (i.e. 
perform a random sampling), +we will persist these outputs outside of a study lifecycle. + +The first time you run a `build` against a study with a statistics config that +has not previously been run before, it will be executed, and it should generate +a table in your database with a timestamp, along with a view that points to that +table. Subsequent updates will not replace that data, unless you provide the +`--statistics` argument. If you do, it will create a new timestamped table, +point the view to your newest table, and leave the old one in place in case +you need to get those results back at a later point in time. + +When you `clean` a study with statistics, by default all statistics artifacts +will be ignored, and persist in the database. If you want to remove these, +the `--statistics` argument will remove all these stats artifacts. \ No newline at end of file diff --git a/docs/statistics/propensity-score-matching.md b/docs/statistics/propensity-score-matching.md index aebbd322..c9774a60 100644 --- a/docs/statistics/propensity-score-matching.md +++ b/docs/statistics/propensity-score-matching.md @@ -31,14 +31,9 @@ The expected workflow looks something like this: ## Configuring a PSM task -You can configure a PSM task the same way that you would configure python/counts table -infrastructure in your manifest.toml. - -TODO: update after this is implemented with example of usage - -The PSM config you are referencing above is expected to contain a number of field -definitions. We :strongly: recommend starting from the below template, which contains -details on the expectations of each value. +The PSM config you reference in your study manifest is expected to contain a number of +field definitions. We :strongly: recommend starting from the below template, which +contains details on the expectations of each value. ```toml # This is a config file for generating a propensity score matching (PSM) definition. 
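# (A note on lifecycle, summarizing docs/statistics.md above: with a manifest
# entry like
#   [statistics_config]
#   file_names = [ "psm_config.toml" ]
# the first `build` run executes this config and writes a timestamped results
# table plus a view pointing at it; later builds leave both alone unless
# --statistics is passed, which writes a fresh timestamped table and repoints
# the view to it.)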
diff --git a/tests/test_templates.py b/tests/test_templates.py index f06935fa..d5e91f20 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -244,7 +244,7 @@ def test_create_view_query_creation(): "expected,schema,table,cols,types", [ ( - """CREATE TABLE IF NOT EXISTS "test_schema"."test_table" + """CREATE TABLE IF NOT EXISTS "test_schema"."test_table" AS ( SELECT * FROM ( VALUES @@ -258,7 +258,7 @@ def test_create_view_query_creation(): [], ), ( - """CREATE TABLE IF NOT EXISTS "test_schema"."test_table" + """CREATE TABLE IF NOT EXISTS "test_schema"."test_table" AS ( SELECT * FROM ( VALUES @@ -277,8 +277,6 @@ def test_ctas_empty_query_creation(expected, schema, table, cols, types): query = get_ctas_empty_query( schema_name=schema, table_name=table, table_cols=cols, table_cols_types=types ) - with open("output.sql", "w") as f: - f.write(query) assert query == expected From 1f58bdf08bb64061fa6f9cfc1661e690492d9490 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Thu, 21 Dec 2023 10:32:24 -0500 Subject: [PATCH 05/13] pylint --- cumulus_library/base_table_builder.py | 6 +++--- cumulus_library/cli.py | 11 +++++------ cumulus_library/databases.py | 1 - cumulus_library/enums.py | 4 ++-- cumulus_library/protected_table_builder.py | 11 +++++------ cumulus_library/study_parser.py | 18 +++++++++--------- tests/test_psm_templates.py | 3 ++- tests/test_study_parser.py | 18 +++++++++--------- 8 files changed, 35 insertions(+), 37 deletions(-) diff --git a/cumulus_library/base_table_builder.py b/cumulus_library/base_table_builder.py index 2e04bd47..b5e56466 100644 --- a/cumulus_library/base_table_builder.py +++ b/cumulus_library/base_table_builder.py @@ -39,8 +39,8 @@ def execute_queries( cursor: DatabaseCursor, schema: str, verbose: bool, - drop_table: bool = False, *args, + drop_table: bool = False, **kwargs, ): """Executes queries set up by a prepare_queries call @@ -77,7 +77,7 @@ def execute_queries( query_console_output(verbose, query, progress, task) try: cursor.execute(query) - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught sys.exit(e) self.post_execution(cursor, schema, verbose, drop_table, *args, **kwargs) @@ -86,8 +86,8 @@ def post_execution( cursor: DatabaseCursor, schema: str, verbose: bool, - drop_table: bool = False, *args, + drop_table: bool = False, **kwargs, ): """Hook for any additional actions to run after execute_queries""" diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index 541e0fe5..7e5297a1 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -16,11 +16,10 @@ from cumulus_library import __version__ from cumulus_library.cli_parser import get_parser from cumulus_library.databases import ( - AthenaDatabaseBackend, DatabaseBackend, create_db_backend, ) -from cumulus_library.enums import PROTECTED_TABLES +from cumulus_library.enums import ProtectedTables from cumulus_library.protected_table_builder import TRANSACTIONS_COLS from cumulus_library.study_parser import StudyManifestParser from cumulus_library.template_sql.templates import get_insert_into_query @@ -42,7 +41,7 @@ def __init__(self, db: DatabaseBackend, data_path: str): def update_transactions(self, prefix: str, status: str): self.cursor.execute( get_insert_into_query( - f"{prefix}__{PROTECTED_TABLES.TRANSACTIONS.value}", + f"{prefix}__{ProtectedTables.TRANSACTIONS.value}", TRANSACTIONS_COLS, [ [ @@ -137,8 +136,8 @@ def clean_and_build_study( stats_build=stats_build, ) self.update_transactions(studyparser.get_study_prefix(), "finished") - 
except SystemExit as exit: - raise exit + except SystemExit as e: + raise e except Exception as e: self.update_transactions(studyparser.get_study_prefix(), "error") raise e @@ -297,7 +296,7 @@ def run_cli(args: Dict): for target in args["target"]: if args["builder"]: builder.run_single_table_builder( - study_dict[target], args["builder"], args["stats_build"] + study_dict[target], args["builder"] ) else: builder.clean_and_build_study( diff --git a/cumulus_library/databases.py b/cumulus_library/databases.py index 3857151e..8e77dcad 100644 --- a/cumulus_library/databases.py +++ b/cumulus_library/databases.py @@ -12,7 +12,6 @@ import datetime import json import os -import sys import warnings from pathlib import Path from typing import Optional, Protocol, Union diff --git a/cumulus_library/enums.py b/cumulus_library/enums.py index d65af4e4..1c12714a 100644 --- a/cumulus_library/enums.py +++ b/cumulus_library/enums.py @@ -2,7 +2,7 @@ from enum import Enum -class PROTECTED_TABLE_KEYWORDS(Enum): +class ProtectedTableKeywords(Enum): """Tables with a pattern like '_{keyword}_' are not manually dropped.""" ETL = "etl" @@ -10,7 +10,7 @@ class PROTECTED_TABLE_KEYWORDS(Enum): NLP = "nlp" -class PROTECTED_TABLES(Enum): +class ProtectedTables(Enum): """Tables created by cumulus for persistence outside of study rebuilds""" STATISTICS = "lib_statistics" diff --git a/cumulus_library/protected_table_builder.py b/cumulus_library/protected_table_builder.py index e208fdc5..b047920e 100644 --- a/cumulus_library/protected_table_builder.py +++ b/cumulus_library/protected_table_builder.py @@ -1,11 +1,8 @@ """ Builder for creating tables for tracking state/logging changes""" -import datetime - from cumulus_library.base_table_builder import BaseTableBuilder -from cumulus_library.enums import PROTECTED_TABLES +from cumulus_library.enums import ProtectedTables from cumulus_library.template_sql.templates import ( get_ctas_empty_query, - get_create_view_query, ) TRANSACTIONS_COLS = ["study_name", "library_version", "status", "event_time"] @@ -20,6 +17,8 @@ class ProtectedTableBuilder(BaseTableBuilder): + """Builder for tables that persist across study clean/build actions""" + display_text = "Creating/updating system tables..." 
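# prepare_queries below always emits a CTAS-if-not-exists statement for
# {study}__lib_transactions (study_name, library_version, status, event_time),
# and emits one for {study}__lib_statistics only when the study declares a
# statistics config (per the study_stats argument).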
def prepare_queries( @@ -28,7 +27,7 @@ def prepare_queries( self.queries.append( get_ctas_empty_query( schema, - f"{study_name}__{PROTECTED_TABLES.TRANSACTIONS.value}", + f"{study_name}__{ProtectedTables.TRANSACTIONS.value}", # while it may seem redundant, study name is included for ease # of constructing a view of multiple transaction tables TRANSACTIONS_COLS, @@ -39,7 +38,7 @@ def prepare_queries( self.queries.append( get_ctas_empty_query( schema, - f"{study_name}__{PROTECTED_TABLES.STATISTICS.value}", + f"{study_name}__{ProtectedTables.STATISTICS.value}", # same redundancy note about study_name, and also view_name, applies here STATISTICS_COLS, [ diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index 5c07e7cf..6e6a2fbf 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -14,7 +14,7 @@ from cumulus_library import __version__ from cumulus_library.base_table_builder import BaseTableBuilder from cumulus_library.databases import DatabaseBackend, DatabaseCursor -from cumulus_library.enums import PROTECTED_TABLE_KEYWORDS, PROTECTED_TABLES +from cumulus_library.enums import ProtectedTableKeywords, ProtectedTables from cumulus_library.errors import StudyManifestParsingError from cumulus_library.helper import ( query_console_output, @@ -200,11 +200,11 @@ def clean_study( for query_and_type in [[view_sql, "VIEW"], [table_sql, "TABLE"]]: tuple_list = cursor.execute(query_and_type[0]).fetchall() if ( - f"{drop_prefix}{PROTECTED_TABLES.STATISTICS.value}", + f"{drop_prefix}{ProtectedTables.STATISTICS.value}", ) in tuple_list and not stats_clean: protected_list = cursor.execute( f"""SELECT {(query_and_type[1]).lower()}_name - FROM {drop_prefix}{PROTECTED_TABLES.STATISTICS.value} + FROM {drop_prefix}{ProtectedTables.STATISTICS.value} WHERE study_name = '{display_prefix}'""" ).fetchall() print(protected_list) @@ -233,7 +233,7 @@ def clean_study( (f"_{word.value}_") in view_table[0] or view_table[0].endswith(word.value) ) - for word in PROTECTED_TABLE_KEYWORDS + for word in ProtectedTableKeywords ): view_table_list.remove(view_table) @@ -260,7 +260,7 @@ def clean_study( ) if stats_clean: drop_query = get_drop_view_table( - f"{drop_prefix}{PROTECTED_TABLES.STATISTICS.value}", "TABLE" + f"{drop_prefix}{ProtectedTables.STATISTICS.value}", "TABLE" ) cursor.execute(drop_query) @@ -335,7 +335,7 @@ def _load_and_execute_builder( # execute, since the subclass would otherwise hang around. table_builder_class = table_builder_subclasses[0] table_builder = table_builder_class() - table_builder.execute_queries(cursor, schema, verbose, drop_table) + table_builder.execute_queries(cursor, schema, verbose, drop_table=drop_table) # After running the executor code, we'll remove # it so it doesn't interfere with the next python module to @@ -431,7 +431,7 @@ def run_statistics_builders( ) insert_query = get_insert_into_query( - f"{self.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + f"{self.get_study_prefix()}__{ProtectedTables.STATISTICS.value}", [ "study_name", "library_version", @@ -530,7 +530,7 @@ def _execute_build_queries( ) if any( f" {self.get_study_prefix()}__{word.value}_" in create_line - for word in PROTECTED_TABLE_KEYWORDS + for word in ProtectedTableKeywords ): self._query_error( query, @@ -538,7 +538,7 @@ def _execute_build_queries( "immediately after the study prefix. 
Please rename this table so " "that it does not begin with one of these special words " "immediately after the double underscore.\n" - f"Reserved words: {str(word.value for word in PROTECTED_TABLE_KEYWORDS)}", + f"Reserved words: {str(word.value for word in ProtectedTableKeywords)}", ) if create_line.count("__") > 1: self._query_error( diff --git a/tests/test_psm_templates.py b/tests/test_psm_templates.py index 5f1d2550..cb6d0c9e 100644 --- a/tests/test_psm_templates.py +++ b/tests/test_psm_templates.py @@ -53,7 +53,8 @@ def test_get_distinct_ids( @pytest.mark.parametrize( - "target,pos_source,neg_source,table_suffix,primary_ref,dep_var,join_cols_by_table,count_ref,count_table,expected,raises", + "target,pos_source,neg_source,table_suffix,primary_ref,dep_var," + "join_cols_by_table,count_ref,count_table,expected,raises", [ ( "target", diff --git a/tests/test_study_parser.py b/tests/test_study_parser.py index 5b6d6f3b..a4e7d291 100644 --- a/tests/test_study_parser.py +++ b/tests/test_study_parser.py @@ -8,7 +8,7 @@ import pytest -from cumulus_library.enums import PROTECTED_TABLE_KEYWORDS, PROTECTED_TABLES +from cumulus_library.enums import ProtectedTableKeywords, ProtectedTables from cumulus_library.study_parser import StudyManifestParser, StudyManifestParsingError from tests.test_data.parser_mock_data import get_mock_toml, mock_manifests @@ -104,7 +104,7 @@ def test_manifest_data(manifest_key, raises): ) def test_clean_study(mock_db, schema, verbose, prefix, confirm, stats, target, raises): with raises: - protected_strs = [x.value for x in PROTECTED_TABLE_KEYWORDS] + protected_strs = [x.value for x in ProtectedTableKeywords] with mock.patch.object(builtins, "input", lambda _: confirm): parser = StudyManifestParser("./tests/test_data/study_valid/") parser.run_protected_table_builder(mock_db.cursor(), schema) @@ -113,7 +113,7 @@ def test_clean_study(mock_db, schema, verbose, prefix, confirm, stats, target, r # is very slow and we're trying a lot of conditions mock_db.cursor().execute( f"CREATE TABLE {parser.get_study_prefix()}__" - f"{PROTECTED_TABLES.STATISTICS.value} " + f"{ProtectedTables.STATISTICS.value} " "AS SELECT 'study_valid' as study_name, " "'study_valid__123' AS table_name" ) @@ -138,16 +138,16 @@ def test_clean_study(mock_db, schema, verbose, prefix, confirm, stats, target, r else: assert (target,) not in remaining_tables assert ( - f"{parser.get_study_prefix()}__{PROTECTED_TABLES.TRANSACTIONS.value}", + f"{parser.get_study_prefix()}__{ProtectedTables.TRANSACTIONS.value}", ) in remaining_tables if stats: assert ( - f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + f"{parser.get_study_prefix()}__{ProtectedTables.STATISTICS.value}", ) not in remaining_tables assert ("study_valid__123",) not in remaining_tables else: assert ( - f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + f"{parser.get_study_prefix()}__{ProtectedTables.STATISTICS.value}", ) in remaining_tables assert ("study_valid__123",) in remaining_tables @@ -168,15 +168,15 @@ def test_run_protected_table_builder(mock_db, study_path, stats): .fetchall() ) assert ( - f"{parser.get_study_prefix()}__{PROTECTED_TABLES.TRANSACTIONS.value}", + f"{parser.get_study_prefix()}__{ProtectedTables.TRANSACTIONS.value}", ) in tables if stats: assert ( - f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", + f"{parser.get_study_prefix()}__{ProtectedTables.STATISTICS.value}", ) in tables else: assert ( - f"{parser.get_study_prefix()}__{PROTECTED_TABLES.STATISTICS.value}", +
f"{parser.get_study_prefix()}__{ProtectedTables.STATISTICS.value}", ) not in tables From 93afe62be4a85f83445f48326c5cfa0c5ff45b9a Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Thu, 21 Dec 2023 12:53:11 -0500 Subject: [PATCH 06/13] self review pass --- cumulus_library/cli.py | 15 ++++++++++++++- cumulus_library/statistics/psm.py | 14 ++------------ cumulus_library/study_parser.py | 1 - docs/statistics/propensity-score-matching.md | 4 ++-- pyproject.toml | 1 - tests/test_data/psm/psm_config.toml | 4 ++-- tests/test_data/psm/psm_config_no_optional.toml | 4 ++-- tests/test_psm.py | 16 ++++++++-------- 8 files changed, 30 insertions(+), 29 deletions(-) diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index 7e5297a1..3750fa1d 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -39,6 +39,7 @@ def __init__(self, db: DatabaseBackend, data_path: str): self.schema_name = db.schema_name def update_transactions(self, prefix: str, status: str): + """Adds a record to a study's transactions table""" self.cursor.execute( get_insert_into_query( f"{prefix}__{ProtectedTables.TRANSACTIONS.value}", @@ -69,7 +70,11 @@ def clean_study( this can be useful for cleaning up tables if a study prefix is changed for some reason. - :param target: The study prefix to use for IDing tables to remove""" + :param target: The study prefix to use for IDing tables to remove + :param study_dict: The dictionary of available study targets + :param stats_clean: If true, removes previous stats runs + :keyword prefix: If True, does a search by string prefix in place of study name + """ if targets is None or targets == ["all"]: sys.exit( "Explicit targets for cleaning not provided. " @@ -103,6 +108,8 @@ def clean_and_build_study( """Recreates study views/tables :param target: A PosixPath to the study directory + :param stats_build: if True, forces creation of new stats tables + :keyword continue_from: Restart a run from a specific sql file (for dev only) """ studyparser = StudyManifestParser(target, self.data_path) try: @@ -125,6 +132,7 @@ def clean_and_build_study( ) else: self.update_transactions(studyparser.get_study_prefix(), "resumed") + studyparser.build_study(self.cursor, self.verbose, continue_from) studyparser.run_counts_builders( self.cursor, self.schema_name, verbose=self.verbose @@ -136,7 +144,10 @@ def clean_and_build_study( stats_build=stats_build, ) self.update_transactions(studyparser.get_study_prefix(), "finished") + except SystemExit as e: + # This should be thrown prior to any database connections, so + # skipping logging raise e except Exception as e: self.update_transactions(studyparser.get_study_prefix(), "error") @@ -148,6 +159,7 @@ def run_single_table_builder( """Runs a single table builder :param target: A PosixPath to the study directory + :param table_builder_name: a builder file referenced in the study's manifest """ studyparser = StudyManifestParser(target) studyparser.run_single_table_builder( @@ -161,6 +173,7 @@ def clean_and_build_all(self, study_dict: Dict, stats_build: bool) -> None: since 99% of the time you don't need a live copy in the database. 
:param study_dict: A dict of PosixPaths + :param stats_build: if True, regen stats tables """ study_dict = dict(study_dict) study_dict.pop("template") diff --git a/cumulus_library/statistics/psm.py b/cumulus_library/statistics/psm.py index c3a71e06..35f45c41 100644 --- a/cumulus_library/statistics/psm.py +++ b/cumulus_library/statistics/psm.py @@ -120,16 +120,9 @@ def _get_sampled_ids( :param is_positive: defines the value to be used for your filtering column """ df = cursor.execute(query).as_pandas() + df = df.sort_values(by=[self.config.primary_ref]) df = ( - df.sort_values(by=[self.config.primary_ref]) - # .reset_index() - # .drop("index", axis=1) - ) - - df = ( - # TODO: flip polarity of replace kwarg after increasing the size of the - # unit testing data - df.sample(n=sample_size, random_state=self.config.seed, replace=True) + df.sample(n=sample_size, random_state=self.config.seed, replace=False) .sort_values(by=[self.config.primary_ref]) .reset_index() .drop("index", axis=1) @@ -308,7 +301,6 @@ def generate_psm_analysis( encoded_df = pandas.get_dummies(df[column]) df = pandas.concat([df, encoded_df], axis=1) df = df.drop(column, axis=1) - df = df.reset_index() try: psm = PsmPy( df, @@ -323,10 +315,8 @@ def generate_psm_analysis( warnings.simplefilter("ignore", category=UserWarning) # This function populates the psm.predicted_data element, which is required # for things like the knn_matched() function call - # TODO: create graph from this data psm.logistic_ps(balance=True) # This function populates the psm.df_matched element - # TODO: create graph from this data psm.knn_matched( matcher="propensity_logit", replacement=False, diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index 6e6a2fbf..f6f54fca 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -403,7 +403,6 @@ def run_statistics_builders( :param schema: The name of the schema to write tables to :keyword verbose: toggle from progress bar to query output :keyword stats_build: If true, will run statistical sampling & table generation - :keyword data_path: A path to where stats output artifacts should be stored """ if not stats_build: return diff --git a/docs/statistics/propensity-score-matching.md b/docs/statistics/propensity-score-matching.md index c9774a60..b5ed1b05 100644 --- a/docs/statistics/propensity-score-matching.md +++ b/docs/statistics/propensity-score-matching.md @@ -63,8 +63,8 @@ pos_source_table = "study__diagnosis_cohort" # neg_source_table should be the primary table your positive source was built from, # i.e. it should contain all members that weren't identified as part of your cohort. -# It should be one of the base FHIR resource tables -neg_source_table = "study__condition" +# It should usually be one of the core FHIR resource tables. +neg_source_table = "core__condition" # target_table should be the name of the table you're storing your PSM cohort in. 
It # should be prefixed by 'studyname__' diff --git a/pyproject.toml b/pyproject.toml index a092795d..f938b17d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,6 @@ dev = [ ] test = [ "freezegun", - "jmespath", "pytest", "requests-mock", ] diff --git a/tests/test_data/psm/psm_config.toml b/tests/test_data/psm/psm_config.toml index 47c81f44..82d4b2da 100644 --- a/tests/test_data/psm/psm_config.toml +++ b/tests/test_data/psm/psm_config.toml @@ -7,8 +7,8 @@ primary_ref = 'encounter_ref' count_ref = 'subject_ref' count_table = 'core__condition' dependent_variable = "example_diagnosis" -pos_sample_size = 50 -neg_sample_size = 500 +pos_sample_size = 20 +neg_sample_size = 100 seed = 1234567890 [join_cols_by_table.core__encounter] join_id = "encounter_ref" diff --git a/tests/test_data/psm/psm_config_no_optional.toml b/tests/test_data/psm/psm_config_no_optional.toml index 36a7b795..81208359 100644 --- a/tests/test_data/psm/psm_config_no_optional.toml +++ b/tests/test_data/psm/psm_config_no_optional.toml @@ -5,6 +5,6 @@ neg_source_table = "core__condition" target_table = "psm_test__psm_encounter_covariate" primary_ref = 'encounter_ref' dependent_variable = "example_diagnosis" -pos_sample_size = 50 -neg_sample_size = 500 +pos_sample_size = 20 +neg_sample_size = 100 seed = 1234567890 diff --git a/tests/test_psm.py b/tests/test_psm.py index 897e5565..9429be85 100644 --- a/tests/test_psm.py +++ b/tests/test_psm.py @@ -17,10 +17,10 @@ [ ( "psm_config.toml", - 52, - 266, + 28, + 129, { - "encounter_ref": "Encounter/03e34b19-2889-b828-792d-2a83400c55be10", + "encounter_ref": "Encounter/03e34b19-2889-b828-792d-2a83400c55be0", "example_diagnosis": "1", "instance_count": 1, "gender": "female", @@ -28,7 +28,7 @@ "code": "33737001", }, { - "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f19", + "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f17", "example_diagnosis": "0", "instance_count": 1, "gender": "female", @@ -38,15 +38,15 @@ ), ( "psm_config_no_optional.toml", - 52, - 266, + 28, + 129, { - "encounter_ref": "Encounter/03e34b19-2889-b828-792d-2a83400c55be10", + "encounter_ref": "Encounter/03e34b19-2889-b828-792d-2a83400c55be0", "example_diagnosis": "1", "code": "33737001", }, { - "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f19", + "encounter_ref": "Encounter/ed151e04-3dd6-8cb7-a3e5-777c8a8667f17", "example_diagnosis": "0", "code": "195662009", }, From 78eddb9ba23002fdd90f78892d64167afd9b4b77 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Thu, 21 Dec 2023 13:07:17 -0500 Subject: [PATCH 07/13] docs tweak --- cumulus_library/statistics/psm.py | 4 ++-- docs/statistics/propensity-score-matching.md | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cumulus_library/statistics/psm.py b/cumulus_library/statistics/psm.py index 35f45c41..65e2b70d 100644 --- a/cumulus_library/statistics/psm.py +++ b/cumulus_library/statistics/psm.py @@ -86,10 +86,10 @@ def __init__(self, toml_config_path: str, data_path: PosixPath): seed=toml_config.get("seed", 123), ) except KeyError: - # TODO: add link to docsite when you have network access sys.exit( f"PSM configuration at {toml_config_path} contains missing/invalid keys." 
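The error path here is the tail end of a load-and-validate pattern: read the TOML, map required keys into a config object, and exit with a pointer to the docs when a key is absent. A condensed sketch, with a hypothetical MiniPsmConfig standing in for the real PsmConfig:

```python
import sys
from dataclasses import dataclass

import toml


@dataclass
class MiniPsmConfig:
    # a hypothetical subset of PsmConfig's fields, for illustration only
    pos_source_table: str
    target_table: str
    primary_ref: str
    seed: int


def load_psm_config(toml_config_path: str) -> MiniPsmConfig:
    toml_config = toml.load(toml_config_path)
    try:
        return MiniPsmConfig(
            pos_source_table=toml_config["pos_source_table"],
            target_table=toml_config["target_table"],
            primary_ref=toml_config["primary_ref"],
            # optional keys use .get() with a default instead of raising
            seed=toml_config.get("seed", 123),
        )
    except KeyError:
        sys.exit(
            f"PSM configuration at {toml_config_path} contains missing/invalid keys.\n"
            "Check the PSM documentation for an example config with more details."
        )
```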
- "Check the PSM documentation for an example config with more details" + "Check the PSM documentation for an example config with more details:\n" + "https://docs.smarthealthit.org/cumulus/library/statistics/propensity-score-matching.html" ) def _get_symptoms_dict(self, path: str) -> dict: diff --git a/docs/statistics/propensity-score-matching.md b/docs/statistics/propensity-score-matching.md index b5ed1b05..fd940cf8 100644 --- a/docs/statistics/propensity-score-matching.md +++ b/docs/statistics/propensity-score-matching.md @@ -1,6 +1,7 @@ --- title: Propensity Score Matching parent: Statistics +grand_parent: Library nav_order: 1 # audience: clinical researchers, IRB reviewers # type: reference From 92d4c96d39f211a560987b74f7339e8827ef1c9e Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Tue, 26 Dec 2023 12:08:06 -0500 Subject: [PATCH 08/13] PR feedback --- cumulus_library/cli.py | 26 ++-- cumulus_library/databases.py | 6 +- cumulus_library/errors.py | 18 +-- cumulus_library/helper.py | 14 ++ cumulus_library/protected_table_builder.py | 25 ++-- cumulus_library/statistics/psm.py | 17 ++- cumulus_library/study_parser.py | 152 ++++++++++++--------- cumulus_library/template_sql/templates.py | 7 +- tests/conftest.py | 4 +- tests/test_core.py | 3 +- tests/test_study_parser.py | 31 +++-- 11 files changed, 179 insertions(+), 124 deletions(-) diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index 3750fa1d..1094e6bb 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -1,19 +1,19 @@ #!/usr/bin/env python3 """Utility for building/retrieving data views in AWS Athena""" +import datetime import json import os import sys import sysconfig -from datetime import datetime from pathlib import Path, PosixPath from typing import Dict, List, Optional from rich.console import Console from rich.table import Table -from cumulus_library import __version__ +from cumulus_library import __version__, errors, helper from cumulus_library.cli_parser import get_parser from cumulus_library.databases import ( DatabaseBackend, @@ -49,7 +49,7 @@ def update_transactions(self, prefix: str, status: str): prefix, __version__, status, - datetime.now().replace(microsecond=0).isoformat(), + helper.get_utc_date(), ] ], ) @@ -61,6 +61,7 @@ def clean_study( self, targets: List[str], study_dict: Dict, + *, stats_clean: bool, prefix: bool = False, ) -> None: @@ -102,6 +103,7 @@ def clean_study( def clean_and_build_study( self, target: PosixPath, + *, stats_build: bool, continue_from: str = None, ) -> None: @@ -145,7 +147,7 @@ def clean_and_build_study( ) self.update_transactions(studyparser.get_study_prefix(), "finished") - except SystemExit as e: + except errors.StudyManifestFilesystemError as e: # This should be thrown prior to any database connections, so # skipping logging raise e @@ -178,10 +180,12 @@ def clean_and_build_all(self, study_dict: Dict, stats_build: bool) -> None: study_dict = dict(study_dict) study_dict.pop("template") for precursor_study in ["vocab", "core"]: - self.clean_and_build_study(study_dict[precursor_study], stats_build) + self.clean_and_build_study( + study_dict[precursor_study], stats_build=stats_build + ) study_dict.pop(precursor_study) for key in study_dict: - self.clean_and_build_study(study_dict[key], stats_build) + self.clean_and_build_study(study_dict[key], stats_build=stats_build) ### Data exporters def export_study(self, target: PosixPath, data_path: PosixPath) -> None: @@ -275,10 +279,10 @@ def run_cli(args: Dict): elif args["action"] == "upload": upload_files(args) - # all 
other actions require connecting to AWS + # all other actions require connecting to the database else: db_backend = create_db_backend(args) - builder = StudyBuilder(db_backend, data_path=args.get("data_path", None)) + builder = StudyBuilder(db_backend, data_path=args.get("data_path")) if args["verbose"]: builder.verbose = True print("Testing connection to database...") @@ -299,8 +303,8 @@ def run_cli(args: Dict): builder.clean_study( args["target"], study_dict, - args["stats_clean"], - args["prefix"], + stats_clean=args["stats_clean"], + prefix=args["prefix"], ) elif args["action"] == "build": if "all" in args["target"]: @@ -314,7 +318,7 @@ def run_cli(args: Dict): else: builder.clean_and_build_study( study_dict[target], - args["stats_build"], + stats_build=args["stats_build"], continue_from=args["continue_from"], ) diff --git a/cumulus_library/databases.py b/cumulus_library/databases.py index 8e77dcad..75b2c489 100644 --- a/cumulus_library/databases.py +++ b/cumulus_library/databases.py @@ -12,7 +12,7 @@ import datetime import json import os -import warnings +import sys from pathlib import Path from typing import Optional, Protocol, Union @@ -265,9 +265,7 @@ def create_db_backend(args: dict[str, str]) -> DatabaseBackend: database, ) if load_ndjson_dir: - warnings.warn( - "Loading an ndjson dir is not supported with --db-type=athena." - ) + sys.exit("Loading an ndjson dir is not supported with --db-type=athena.") else: raise ValueError(f"Unexpected --db-type value '{db_type}'") diff --git a/cumulus_library/errors.py b/cumulus_library/errors.py index 1bcc96f4..da4aab31 100644 --- a/cumulus_library/errors.py +++ b/cumulus_library/errors.py @@ -3,13 +3,7 @@ class CumulusLibraryError(Exception): """ - Package level error - """ - - -class CumulusLibrarySchemaError(Exception): - """ - Package level error + Generic package level error """ @@ -17,5 +11,13 @@ class CountsBuilderError(Exception): """Basic error for CountsBuilder""" +class StudyManifestFilesystemError(Exception): + """Errors related to files on disk in StudyManifestParser""" + + class StudyManifestParsingError(Exception): - """Basic error for StudyManifestParser""" + """Errors related to manifest parsing in StudyManifestParser""" + + +class StudyManifestQueryError(Exception): + """Errors related to data queries from StudyManifestParser""" diff --git a/cumulus_library/helper.py b/cumulus_library/helper.py index 07138f71..f9a54ed3 100644 --- a/cumulus_library/helper.py +++ b/cumulus_library/helper.py @@ -1,4 +1,5 @@ """ Collection of small commonly used utility functions """ +import datetime import os import json from typing import List @@ -66,3 +67,16 @@ def get_progress_bar(**kwargs) -> progress.Progress: progress.TimeRemainingColumn(elapsed_when_finished=True), **kwargs, ) + + +def get_utc_date() -> datetime.datetime: + return ( + datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0).isoformat() + ) + + +def get_tablename_safe_iso_timestamp() -> str: + """formats a timestamp to remove sql unallowed characters in table names""" + iso_timestamp = get_utc_date() + safe_timestamp = iso_timestamp.replace(":", "_").replace("-", "_").replace("+", "_") + return safe_timestamp diff --git a/cumulus_library/protected_table_builder.py b/cumulus_library/protected_table_builder.py index b047920e..330f1a96 100644 --- a/cumulus_library/protected_table_builder.py +++ b/cumulus_library/protected_table_builder.py @@ -6,6 +6,9 @@ ) TRANSACTIONS_COLS = ["study_name", "library_version", "status", "event_time"] +TRANSACTION_COLS_TYPES = 
["varchar", "varchar", "varchar", "timestamp"] +# while it may seem redundant, study_name and view_name are included as a column for +# ease of constructing a view of multiple transaction tables STATISTICS_COLS = [ "study_name", "library_version", @@ -14,6 +17,14 @@ "view_name", "created_on", ] +STATISTICS_COLS_TYPES = [ + "varchar", + "varchar", + "varchar", + "varchar", + "varchar", + "timestamp", +] class ProtectedTableBuilder(BaseTableBuilder): @@ -28,10 +39,8 @@ def prepare_queries( get_ctas_empty_query( schema, f"{study_name}__{ProtectedTables.TRANSACTIONS.value}", - # while it may seem redundant, study name is included for ease - # of constructing a view of multiple transaction tables TRANSACTIONS_COLS, - ["varchar", "varchar", "varchar", "timestamp"], + TRANSACTION_COLS_TYPES, ) ) if study_stats: @@ -39,15 +48,7 @@ def prepare_queries( get_ctas_empty_query( schema, f"{study_name}__{ProtectedTables.STATISTICS.value}", - # same redundancy note about study_name, and also view_name, applies here STATISTICS_COLS, - [ - "varchar", - "varchar", - "varchar", - "varchar", - "varchar", - "timestamp", - ], + STATISTICS_COLS_TYPES, ) ) diff --git a/cumulus_library/statistics/psm.py b/cumulus_library/statistics/psm.py index 65e2b70d..b094352f 100644 --- a/cumulus_library/statistics/psm.py +++ b/cumulus_library/statistics/psm.py @@ -37,6 +37,11 @@ class PsmConfig: These values should be read in from a toml configuration file. See docs/statistics/propensity-score-matching.md for an example with details about the expected values for these fields. + + A word of caution about sampling: the assumptions around PSM analysis require + that any sampling should not use replacement, so do not turn on panda's dataframe + replacement. This will mean that very small population sizes (i.e. < 20ish) + may cause errors to be generated. 
""" classification_json: str @@ -122,7 +127,7 @@ def _get_sampled_ids( df = cursor.execute(query).as_pandas() df = df.sort_values(by=[self.config.primary_ref]) df = ( - df.sample(n=sample_size, random_state=self.config.seed, replace=False) + df.sample(n=sample_size, random_state=self.config.seed) .sort_values(by=[self.config.primary_ref]) .reset_index() .drop("index", axis=1) @@ -181,7 +186,6 @@ def _create_covariate_table( count_table=self.config.count_table, ) self.queries.append(dataset_query) - print(dataset_query) def psm_plot_match( self, @@ -261,10 +265,11 @@ def psm_effect_size_plot( def generate_psm_analysis( self, cursor: DatabaseCursor, schema: str, table_suffix: str ): + stats_table = f"{self.config.target_table}_{table_suffix}" """Runs PSM statistics on generated tables""" cursor.execute( f"""CREATE OR REPLACE VIEW {self.config.target_table} - AS SELECT * FROM {self.config.target_table}_{table_suffix}""" + AS SELECT * FROM {stats_table}""" ) df = cursor.execute(f"SELECT * FROM {self.config.target_table}").as_pandas() symptoms_dict = self._get_symptoms_dict(self.config.classification_json) @@ -327,14 +332,12 @@ def generate_psm_analysis( self.psm_plot_match( psm, save=True, - filename=self.data_path - / f"{self.config.target_table}_{table_suffix}_propensity_match.png", + filename=self.data_path / f"{stats_table}_propensity_match.png", ) self.psm_effect_size_plot( psm, save=True, - filename=self.data_path - / f"{self.config.target_table}_{table_suffix}_effect_size.png", + filename=self.data_path / f"{stats_table}_effect_size.png", ) except ZeroDivisionError: sys.exit( diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index f6f54fca..a81f9d95 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -1,9 +1,9 @@ """ Contains classes for loading study data based on manifest.toml files """ +import datetime import inspect import importlib.util import sys -from datetime import datetime from pathlib import Path, PosixPath from typing import List, Optional @@ -11,17 +11,11 @@ from rich.progress import Progress, TaskID, track -from cumulus_library import __version__ +from cumulus_library import __version__, helper +from cumulus_library import errors from cumulus_library.base_table_builder import BaseTableBuilder from cumulus_library.databases import DatabaseBackend, DatabaseCursor from cumulus_library.enums import ProtectedTableKeywords, ProtectedTables -from cumulus_library.errors import StudyManifestParsingError -from cumulus_library.helper import ( - query_console_output, - load_text, - parse_sql, - get_progress_bar, -) from cumulus_library.protected_table_builder import ProtectedTableBuilder from cumulus_library.statistics.psm import PsmBuilder from cumulus_library.template_sql.templates import ( @@ -75,13 +69,13 @@ def load_study_manifest(self, study_path: Path) -> None: if not config.get("study_prefix") or not isinstance( config["study_prefix"], str ): - raise StudyManifestParsingError( + raise errors.StudyManifestParsingError( f"Invalid prefix in manifest at {study_path}" ) self._study_config = config self._study_path = study_path except FileNotFoundError: - raise StudyManifestParsingError( # pylint: disable=raise-missing-from + raise errors.StudyManifestFilesystemError( # pylint: disable=raise-missing-from f"Missing or invalid manifest found at {study_path}" ) @@ -105,7 +99,9 @@ def get_sql_file_list(self, continue_from: str = None) -> Optional[StrList]: sql_files = sql_files[pos:] break else: - sys.exit(f"No tables 
matching '{continue_from}' found") + raise errors.StudyManifestParsingError( + f"No tables matching '{continue_from}' found" + ) return sql_files def get_table_builder_file_list(self) -> Optional[StrList]: @@ -141,26 +137,70 @@ def get_export_table_list(self) -> Optional[StrList]: export_table_list = export_config.get("export_list", []) for table in export_table_list: if not table.startswith(f"{self.get_study_prefix()}__"): - raise StudyManifestParsingError( + raise errors.StudyManifestParsingError( f"{table} in export list does not start with prefix " f"{self.get_study_prefix()}__ - check your manifest file." ) return export_table_list - def reset_data_dir(self) -> None: + def reset_counts_exports(self) -> None: """ Removes exports associated with this study from the ../data_export directory. """ - print(self.data_path) - print(type(self.data_path)) path = Path(f"{self.data_path}/{self.get_study_prefix()}") if path.exists(): - # we're just going to remove the count files - exports related to stats - # that aren't uploaded to the aggregator are left alone. + # we're just going to remove the count exports - stats exports in + # subdirectories are left alone by this call for file in path.glob("*.*"): file.unlink() # SQL related functions + + def get_unprotected_stats_view_table( + self, + cursor: DatabaseCursor, + query: str, + artifact_type: str, + drop_prefix: str, + display_prefix: str, + stats_clean: bool, + ): + """Gets all items from the database by type, less any protected items + + :param cursor: An object of type DatabaseCursor + :param query: A query to get the raw list of items from the db + :param artifact_type: either 'table' or 'view' + :param drop_prefix: The prefix requested to drop + :param display_prefix: The expected study prefix + :param stats_clean: A boolean indicating if stats tables are being cleaned + + :returns: a list of study tables to drop + """ + unprotected_list = [] + db_contents = cursor.execute(query).fetchall() + if ( + f"{drop_prefix}{ProtectedTables.STATISTICS.value}", + ) in db_contents and not stats_clean: + protected_list = cursor.execute( + f"""SELECT {artifact_type.lower()}_name + FROM {drop_prefix}{ProtectedTables.STATISTICS.value} + WHERE study_name = '{display_prefix}'""" + ).fetchall() + for protected_tuple in protected_list: + if protected_tuple in db_contents: + db_contents.remove(protected_tuple) + for db_row_tuple in db_contents: + # this check handles athena reporting views as also being tables, + # so we don't waste time dropping things that don't exist + if artifact_type == "TABLE": + if not any( + db_row_tuple[0] in iter_q_and_t for iter_q_and_t in unprotected_list + ): + unprotected_list.append([db_row_tuple[0], artifact_type]) + else: + unprotected_list.append([db_row_tuple[0], artifact_type]) + return unprotected_list + def clean_study( self, cursor: DatabaseCursor, @@ -171,7 +211,7 @@ def clean_study( ) -> List: """Removes tables beginning with the study prefix from the database schema - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param schema_name: The name of the schema containing the study tables :verbose: toggle from progress bar to query output, optional :returns: list of dropped tables (for unit testing only) @@ -196,32 +236,15 @@ def clean_study( view_sql = get_show_views(schema_name, drop_prefix) table_sql = get_show_tables(schema_name, drop_prefix) - view_table_list = [] for query_and_type in [[view_sql, "VIEW"], [table_sql, "TABLE"]]: - tuple_list = 
cursor.execute(query_and_type[0]).fetchall() - if ( - f"{drop_prefix}{ProtectedTables.STATISTICS.value}", - ) in tuple_list and not stats_clean: - protected_list = cursor.execute( - f"""SELECT {(query_and_type[1]).lower()}_name - FROM {drop_prefix}{ProtectedTables.STATISTICS.value} - WHERE study_name = '{display_prefix}'""" - ).fetchall() - print(protected_list) - for protected_tuple in protected_list: - if protected_tuple in tuple_list: - tuple_list.remove(protected_tuple) - for db_row_tuple in tuple_list: - # this check handles athena reporting views as also being tables, - # so we don't waste time dropping things that don't exist - if query_and_type[1] == "TABLE": - if not any( - db_row_tuple[0] in iter_q_and_t - for iter_q_and_t in view_table_list - ): - view_table_list.append([db_row_tuple[0], query_and_type[1]]) - else: - view_table_list.append([db_row_tuple[0], query_and_type[1]]) + view_table_list = self.get_unprotected_stats_view_table( + cursor, + query_and_type[0], + query_and_type[1], + drop_prefix, + display_prefix, + stats_clean, + ) if not view_table_list: return view_table_list @@ -245,7 +268,7 @@ def clean_study( if confirm.lower() not in ("y", "yes"): sys.exit("Table cleaning aborted") # We want to only show a progress bar if we are :not: printing SQL lines - with get_progress_bar(disable=verbose) as progress: + with helper.get_progress_bar(disable=verbose) as progress: task = progress.add_task( f"Removing {display_prefix} study artifacts...", total=len(view_table_list), @@ -258,6 +281,8 @@ def clean_study( progress, task, ) + # if we're doing a stats clean, we'll also remove the table containing the + # list of protected tables if stats_clean: drop_query = get_drop_view_table( f"{drop_prefix}{ProtectedTables.STATISTICS.value}", "TABLE" @@ -276,7 +301,7 @@ def _execute_drop_queries( ) -> None: """Handler for executing drop view/table queries and displaying console output. - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param verbose: toggle from progress bar to query output :param view_table_list: a list of views and tables beginning with the study prefix :param progress: a rich progress bar renderer @@ -287,7 +312,7 @@ def _execute_drop_queries( name=view_table[0], view_or_table=view_table[1] ) cursor.execute(drop_view_table) - query_console_output(verbose, drop_view_table, progress, task) + helper.query_console_output(verbose, drop_view_table, progress, task) def _load_and_execute_builder( self, filename, cursor, schema, verbose, drop_table=False @@ -320,7 +345,7 @@ def _load_and_execute_builder( table_builder_subclasses.append(cls_obj) if len(table_builder_subclasses) == 0: - raise StudyManifestParsingError( + raise errors.StudyManifestParsingError( f"Error loading {self._study_path}{filename}\n" "Custom builders must extend the BaseTableBuilder class." 
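For reference, the smallest file that passes this check is a single subclass in the study directory, roughly as below; the display_text attribute is assumed here as the progress-bar label, and the table name is hypothetical.

```python
from cumulus_library.base_table_builder import BaseTableBuilder


class ExampleTableBuilder(BaseTableBuilder):
    display_text = "Creating example tables..."

    def prepare_queries(self, cursor: object, schema: str, *args, **kwargs):
        # queries staged here are executed later by execute_queries()
        self.queries.append(
            "CREATE TABLE IF NOT EXISTS my_study__example AS SELECT 1 AS id"
        )
```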
) @@ -348,7 +373,7 @@ def run_protected_table_builder( ) -> None: """Creates protected tables for persisting selected data across runs - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param schema: The name of the schema to write tables to :param verbose: toggle from progress bar to query output """ @@ -366,7 +391,7 @@ def run_table_builder( ) -> None: """Loads modules from a manifest and executes code via BaseTableBuilder - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param schema: The name of the schema to write tables to :param verbose: toggle from progress bar to query output """ @@ -383,7 +408,7 @@ def run_counts_builders( given dataset, where other statistical methods may use sampling techniques or adjustable input parameters that may need to be preserved for later review. - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param schema: The name of the schema to write tables to :param verbose: toggle from progress bar to query output """ @@ -399,7 +424,7 @@ def run_statistics_builders( ) -> None: """Loads statistics modules from toml definitions and executes - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param schema: The name of the schema to write tables to :keyword verbose: toggle from progress bar to query output :keyword stats_build: If true, will run statistical sampling & table generation @@ -410,8 +435,7 @@ def run_statistics_builders( # This open is a bit redundant with the open inside of the PSM builder, # but we're letting it slide so that builders function similarly # across the board - iso_timestamp = datetime.now().replace(microsecond=0).isoformat() - safe_timestamp = iso_timestamp.replace(":", "_").replace("-", "_") + safe_timestamp = helper.get_tablename_safe_iso_timestamp() toml_path = Path(f"{self._study_path}/{file}") with open(toml_path, encoding="UTF-8") as file: config = toml.load(file) @@ -422,7 +446,7 @@ def run_statistics_builders( toml_path, self.data_path / f"{self.get_study_prefix()}/psm" ) else: - raise StudyManifestParsingError( + raise errors.StudyManifestParsingError( f"{toml_path} references an invalid statistics type {config_type}." 
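Each statistics run lands in its own timestamped table, with the suffix derived from the helpers added to helper.py earlier in this commit. A worked example of that naming, using a fixed datetime:

```python
import datetime

now = datetime.datetime(2023, 12, 26, 16, 15, 53, tzinfo=datetime.timezone.utc)
iso_timestamp = now.replace(microsecond=0).isoformat()
# '2023-12-26T16:15:53+00:00' -> swap out the characters SQL table names reject
safe_timestamp = (
    iso_timestamp.replace(":", "_").replace("-", "_").replace("+", "_")
)
print(f"psm_test__psm_encounter_covariate_{safe_timestamp}")
# psm_test__psm_encounter_covariate_2023_12_26T16_15_53_00_00
```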
) builder.execute_queries( @@ -446,7 +470,7 @@ def run_statistics_builders( config_type, f"{target_table}_{safe_timestamp}", target_table, - iso_timestamp, + helper.get_utc_date(), ] ], ) @@ -465,19 +489,21 @@ def build_study( ) -> List: """Creates tables in the schema by iterating through the sql_config.file_names - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param schema: The name of the schema to write tables to :param verbose: toggle from progress bar to query output, optional :returns: loaded queries (for unit testing only) """ queries = [] for file in self.get_sql_file_list(continue_from): - for query in parse_sql(load_text(f"{self._study_path}/{file}")): + for query in helper.parse_sql( + helper.load_text(f"{self._study_path}/{file}") + ): queries.append([query, file]) if len(queries) == 0: return [] # We want to only show a progress bar if we are :not: printing SQL lines - with get_progress_bar(disable=verbose) as progress: + with helper.get_progress_bar(disable=verbose) as progress: task = progress.add_task( f"Creating {self.get_study_prefix()} study in db...", total=len(queries), @@ -500,7 +526,7 @@ def _query_error(self, query_and_filename: List, exit_message: str) -> None: print("--------", file=sys.stderr) print(query_and_filename[0], file=sys.stderr) print("--------", file=sys.stderr) - sys.exit(exit_message) + raise errors.StudyManifestQueryError(exit_message) def _execute_build_queries( self, @@ -512,7 +538,7 @@ def _execute_build_queries( ) -> None: """Handler for executing create table queries and displaying console output. - :param cursor: A PEP-249 compatible cursor object + :param cursor: A DatabaseCursor object :param verbose: toggle from progress bar to query output :param queries: a list of queries read from files in sql_config.file_names :param progress: a rich progress bar renderer @@ -555,7 +581,7 @@ def _execute_build_queries( ) try: cursor.execute(query[0]) - query_console_output(verbose, query[0], progress, task) + helper.query_console_output(verbose, query[0], progress, task) except Exception as e: # pylint: disable=broad-exception-caught self._query_error( query, @@ -572,7 +598,7 @@ def export_study(self, db: DatabaseBackend, data_path: PosixPath) -> List: :param db: A database backend :returns: list of executed queries (for unit testing only) """ - self.reset_data_dir() + self.reset_counts_exports() queries = [] for table in track( self.get_export_table_list(), diff --git a/cumulus_library/template_sql/templates.py b/cumulus_library/template_sql/templates.py index 9c7f7c5b..3e5fc238 100644 --- a/cumulus_library/template_sql/templates.py +++ b/cumulus_library/template_sql/templates.py @@ -206,7 +206,7 @@ def get_ctas_empty_query( schema_name: str, table_name: str, table_cols: List[str], - table_cols_types: List[str] = [], + table_cols_types: List[str] = None, ) -> str: """Generates a create table as query for initializing an empty table @@ -221,9 +221,8 @@ def get_ctas_empty_query( :param table_cols_types: Allows specifying a data type per column (default: all varchar) """ path = Path(__file__).parent - if table_cols_types == []: - for col in table_cols: - table_cols_types.append("varchar") + if not table_cols_types: + table_cols_types = ["varchar"] * len(table_cols) with open(f"{path}/ctas_empty.sql.jinja") as ctas_empty: return Template(ctas_empty.read()).render( schema_name=schema_name, diff --git a/tests/conftest.py b/tests/conftest.py index 60b1511c..454343e4 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -173,7 +173,7 @@ def mock_db_core(tmp_path, mock_db): # pylint: disable=redefined-outer-name """Provides a DuckDatabaseBackend with the core study ran for local testing""" builder = StudyBuilder(mock_db, data_path=f"{tmp_path}/data_path") builder.clean_and_build_study( - f"{Path(__file__).parent.parent}/cumulus_library/studies/core", True + f"{Path(__file__).parent.parent}/cumulus_library/studies/core", stats_build=True ) yield mock_db @@ -191,6 +191,6 @@ def mock_db_stats(tmp_path): ) builder = StudyBuilder(db, data_path=f"{tmp_path}/data_path") builder.clean_and_build_study( - f"{Path(__file__).parent.parent}/cumulus_library/studies/core", True + f"{Path(__file__).parent.parent}/cumulus_library/studies/core", stats_build=True ) yield db diff --git a/tests/test_core.py b/tests/test_core.py index 6a1e37a2..7d838a71 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -76,8 +76,7 @@ def test_core_count_missing_data(tmp_path, mock_db): builder = StudyBuilder(mock_db, f"{tmp_path}/data_path/") builder.clean_and_build_study( f"{Path(__file__).parent.parent}/cumulus_library/studies/core", - f"{tmp_path}/data_path/", - False, + stats_build=False, ) table_rows = cursor.execute("SELECT * FROM core__count_encounter_month").fetchall() # For regenerating data if needed diff --git a/tests/test_study_parser.py b/tests/test_study_parser.py index a4e7d291..2bcdec21 100644 --- a/tests/test_study_parser.py +++ b/tests/test_study_parser.py @@ -8,8 +8,9 @@ import pytest +from cumulus_library import errors from cumulus_library.enums import ProtectedTableKeywords, ProtectedTables -from cumulus_library.study_parser import StudyManifestParser, StudyManifestParsingError +from cumulus_library.study_parser import StudyManifestParser from tests.test_data.parser_mock_data import get_mock_toml, mock_manifests @@ -18,10 +19,13 @@ [ ("test_data/study_valid", does_not_raise()), (None, does_not_raise()), - ("test_data/study_missing_prefix", pytest.raises(StudyManifestParsingError)), - ("test_data/study_wrong_type", pytest.raises(StudyManifestParsingError)), - ("", pytest.raises(StudyManifestParsingError)), - (".", pytest.raises(StudyManifestParsingError)), + ( + "test_data/study_missing_prefix", + pytest.raises(errors.StudyManifestParsingError), + ), + ("test_data/study_wrong_type", pytest.raises(errors.StudyManifestParsingError)), + ("", pytest.raises(errors.StudyManifestFilesystemError)), + (".", pytest.raises(errors.StudyManifestFilesystemError)), ], ) def test_load_manifest(manifest_path, raises): @@ -40,7 +44,7 @@ def test_load_manifest(manifest_path, raises): ("valid_empty_arrays", does_not_raise()), ("valid_null_arrays", does_not_raise()), ("valid_only_prefix", does_not_raise()), - ("invalid_bad_export_names", pytest.raises(StudyManifestParsingError)), + ("invalid_bad_export_names", pytest.raises(errors.StudyManifestParsingError)), ("invalid_none", pytest.raises(TypeError)), ], ) @@ -205,7 +209,7 @@ def test_run_protected_table_builder(mock_db, study_path, stats): "./tests/test_data/study_python_no_subclass/", True, (), - pytest.raises(StudyManifestParsingError), + pytest.raises(errors.StudyManifestParsingError), ), ], ) @@ -246,24 +250,29 @@ def test_table_builder(mock_db, study_path, verbose, expects, raises): ("study_valid__table",), does_not_raise(), ), - ("./tests/test_data/study_wrong_prefix/", None, [], pytest.raises(SystemExit)), + ( + "./tests/test_data/study_wrong_prefix/", + None, + [], + pytest.raises(errors.StudyManifestQueryError), + ), ( 
"./tests/test_data/study_invalid_no_dunder/", True, (), - pytest.raises(SystemExit), + pytest.raises(errors.StudyManifestQueryError), ), ( "./tests/test_data/study_invalid_two_dunder/", True, (), - pytest.raises(SystemExit), + pytest.raises(errors.StudyManifestQueryError), ), ( "./tests/test_data/study_invalid_reserved_word/", True, (), - pytest.raises(SystemExit), + pytest.raises(errors.StudyManifestQueryError), ), ], ) From 8d79e7c4b951aad2b3c7b5287571a256b47af3a1 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Tue, 26 Dec 2023 13:46:13 -0500 Subject: [PATCH 09/13] typecasts for athena --- cumulus_library/cli.py | 5 +++-- cumulus_library/helper.py | 8 +++----- cumulus_library/study_parser.py | 2 +- cumulus_library/template_sql/insert_into.sql.jinja | 4 ++++ cumulus_library/template_sql/templates.py | 10 ++++++++-- tests/test_templates.py | 12 ++++++++++++ 6 files changed, 31 insertions(+), 10 deletions(-) diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index 1094e6bb..264376fc 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -49,9 +49,10 @@ def update_transactions(self, prefix: str, status: str): prefix, __version__, status, - helper.get_utc_date(), + helper.get_utc_datetime(), ] ], + {"event_time": "TIMESTAMP"}, ) ) @@ -152,7 +153,7 @@ def clean_and_build_study( # skipping logging raise e except Exception as e: - self.update_transactions(studyparser.get_study_prefix(), "error") + # self.update_transactions(studyparser.get_study_prefix(), "error") raise e def run_single_table_builder( diff --git a/cumulus_library/helper.py b/cumulus_library/helper.py index f9a54ed3..e66aa3bf 100644 --- a/cumulus_library/helper.py +++ b/cumulus_library/helper.py @@ -69,14 +69,12 @@ def get_progress_bar(**kwargs) -> progress.Progress: ) -def get_utc_date() -> datetime.datetime: - return ( - datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0).isoformat() - ) +def get_utc_datetime() -> datetime.datetime: + return datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) def get_tablename_safe_iso_timestamp() -> str: """formats a timestamp to remove sql unallowed characters in table names""" - iso_timestamp = get_utc_date() + iso_timestamp = get_utc_datetime().isoformat() safe_timestamp = iso_timestamp.replace(":", "_").replace("-", "_").replace("+", "_") return safe_timestamp diff --git a/cumulus_library/study_parser.py b/cumulus_library/study_parser.py index a81f9d95..b8604d3d 100644 --- a/cumulus_library/study_parser.py +++ b/cumulus_library/study_parser.py @@ -470,7 +470,7 @@ def run_statistics_builders( config_type, f"{target_table}_{safe_timestamp}", target_table, - helper.get_utc_date(), + helper.get_utc_datetime(), ] ], ) diff --git a/cumulus_library/template_sql/insert_into.sql.jinja b/cumulus_library/template_sql/insert_into.sql.jinja index cc058abd..e643d886 100644 --- a/cumulus_library/template_sql/insert_into.sql.jinja +++ b/cumulus_library/template_sql/insert_into.sql.jinja @@ -11,7 +11,11 @@ VALUES {%- for row in dataset %} ( {%- for field in row -%} + {%- if table_cols[loop.index0] in type_casts.keys() -%} + {{ type_casts[table_cols[loop.index0]] }} '{{ field }}' + {%- else -%} '{{ field }}' + {%- endif -%} {%- if not loop.last -%} , {%- endif -%} diff --git a/cumulus_library/template_sql/templates.py b/cumulus_library/template_sql/templates.py index 3e5fc238..463a91d0 100644 --- a/cumulus_library/template_sql/templates.py +++ b/cumulus_library/template_sql/templates.py @@ -269,7 +269,10 @@ def get_extension_denormalize_query(config: 
ExtensionConfig) -> str: def get_insert_into_query( - table_name: str, table_cols: List[str], dataset: List[List[str]] + table_name: str, + table_cols: List[str], + dataset: List[List[str]], + type_casts: Dict = {}, ) -> str: """Generates an insert query for adding data to an existing athena table @@ -281,7 +284,10 @@ def get_insert_into_query( path = Path(__file__).parent with open(f"{path}/insert_into.sql.jinja") as insert_into: return Template(insert_into.read()).render( - table_name=table_name, table_cols=table_cols, dataset=dataset + table_name=table_name, + table_cols=table_cols, + dataset=dataset, + type_casts=type_casts, ) diff --git a/tests/test_templates.py b/tests/test_templates.py index d5e91f20..e154f124 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -438,6 +438,18 @@ def test_insert_into_query_creation(): dataset=[["foo", "foo"], ["bar", "bar"]], ) assert query == expected + expected = """INSERT INTO test_table +("a","b") +VALUES +('foo',VARCHAR 'foo'), +('bar',VARCHAR 'bar');""" + query = get_insert_into_query( + table_name="test_table", + table_cols=["a", "b"], + dataset=[["foo", "foo"], ["bar", "bar"]], + type_casts={"b": "VARCHAR"}, + ) + assert query == expected def test_is_table_not_empty(): From befd3b610b281669a38991db65f7cdbedcffccf2 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Tue, 26 Dec 2023 13:49:42 -0500 Subject: [PATCH 10/13] sqlfluff vars --- cumulus_library/.sqlfluff | 1 + 1 file changed, 1 insertion(+) diff --git a/cumulus_library/.sqlfluff b/cumulus_library/.sqlfluff index d755da80..2c807116 100644 --- a/cumulus_library/.sqlfluff +++ b/cumulus_library/.sqlfluff @@ -51,6 +51,7 @@ table_name = test_table table_suffix = 2024_01_01_11_11_11 target_col_prefix = prefix target_table = target_table +type_casts={"b": "VARCHAR"} unnests = [{"source col": "g", "table_alias": "i", "row_alias":"j"}, {"source col": "k", "table_alias": "l", "row_alias":"m"},] view_cols = ["c","d"] view_name = test_view From 63605abadf1e78c1891608c4d79ed35882dcd9cb Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Tue, 26 Dec 2023 16:15:53 -0500 Subject: [PATCH 11/13] test transaction log --- cumulus_library/cli.py | 103 +++++++++--------- tests/test_cli.py | 48 +++++++- .../study_invalid_bad_query/manifest.toml | 7 ++ .../study_invalid_bad_query/test.sql | 1 + 4 files changed, 105 insertions(+), 54 deletions(-) create mode 100644 tests/test_data/study_invalid_bad_query/manifest.toml create mode 100644 tests/test_data/study_invalid_bad_query/test.sql diff --git a/cumulus_library/cli.py b/cumulus_library/cli.py index 264376fc..77b7088e 100755 --- a/cumulus_library/cli.py +++ b/cumulus_library/cli.py @@ -153,7 +153,7 @@ def clean_and_build_study( # skipping logging raise e except Exception as e: - # self.update_transactions(studyparser.get_study_prefix(), "error") + self.update_transactions(studyparser.get_study_prefix(), "error") raise e def run_single_table_builder( @@ -282,57 +282,56 @@ def run_cli(args: Dict): # all other actions require connecting to the database else: - db_backend = create_db_backend(args) - builder = StudyBuilder(db_backend, data_path=args.get("data_path")) - if args["verbose"]: - builder.verbose = True - print("Testing connection to database...") - builder.cursor.execute("SHOW DATABASES") - - study_dict = get_study_dict(args["study_dir"]) - if "prefix" not in args.keys(): - if args["target"]: - for target in args["target"]: - if target not in study_dict: - sys.exit( - f"{target} was not found in available studies: " - 
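A usage sketch for the type_casts hook from the previous commit (table contents are hypothetical); this is what lets the transactions insert write a real timestamp to Athena rather than a bare string:

```python
from cumulus_library.template_sql.templates import get_insert_into_query

query = get_insert_into_query(
    table_name="my_study__lib_transactions",
    table_cols=["study_name", "library_version", "status", "event_time"],
    dataset=[["my_study", "1.0.0", "started", "2023-12-26T16:15:53+00:00"]],
    type_casts={"event_time": "TIMESTAMP"},
)
print(query)
# INSERT INTO my_study__lib_transactions
# ("study_name","library_version","status","event_time")
# VALUES
# ('my_study','1.0.0','started',TIMESTAMP '2023-12-26T16:15:53+00:00');
```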
f"{list(study_dict.keys())}.\n\n" - "If you are trying to run a custom study, make sure " - "you include `-s path/to/study/dir` as an arugment." - ) - if args["action"] == "clean": - builder.clean_study( - args["target"], - study_dict, - stats_clean=args["stats_clean"], - prefix=args["prefix"], - ) - elif args["action"] == "build": - if "all" in args["target"]: - builder.clean_and_build_all(study_dict, args["stats_build"]) - else: - for target in args["target"]: - if args["builder"]: - builder.run_single_table_builder( - study_dict[target], args["builder"] - ) - else: - builder.clean_and_build_study( - study_dict[target], - stats_build=args["stats_build"], - continue_from=args["continue_from"], - ) - - elif args["action"] == "export": - if "all" in args["target"]: - builder.export_all(study_dict, args["data_path"]) - else: - for target in args["target"]: - builder.export_study(study_dict[target], args["data_path"]) - - db_backend.close() - # returning the builder for ease of unit testing - return builder + try: + db_backend = create_db_backend(args) + builder = StudyBuilder(db_backend, data_path=args.get("data_path")) + if args["verbose"]: + builder.verbose = True + print("Testing connection to database...") + builder.cursor.execute("SHOW DATABASES") + + study_dict = get_study_dict(args["study_dir"]) + if "prefix" not in args.keys(): + if args["target"]: + for target in args["target"]: + if target not in study_dict: + sys.exit( + f"{target} was not found in available studies: " + f"{list(study_dict.keys())}.\n\n" + "If you are trying to run a custom study, make sure " + "you include `-s path/to/study/dir` as an arugment." + ) + if args["action"] == "clean": + builder.clean_study( + args["target"], + study_dict, + stats_clean=args["stats_clean"], + prefix=args["prefix"], + ) + elif args["action"] == "build": + if "all" in args["target"]: + builder.clean_and_build_all(study_dict, args["stats_build"]) + else: + for target in args["target"]: + if args["builder"]: + builder.run_single_table_builder( + study_dict[target], args["builder"] + ) + else: + builder.clean_and_build_study( + study_dict[target], + stats_build=args["stats_build"], + continue_from=args["continue_from"], + ) + + elif args["action"] == "export": + if "all" in args["target"]: + builder.export_all(study_dict, args["data_path"]) + else: + for target in args["target"]: + builder.export_study(study_dict[target], args["data_path"]) + finally: + db_backend.close() def main(cli_args=None): diff --git a/tests/test_cli.py b/tests/test_cli.py index 0bbcd4d8..68ae24e8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,7 +13,7 @@ import requests_mock import toml -from cumulus_library import cli +from cumulus_library import cli, errors from cumulus_library.databases import DuckDatabaseBackend from tests.conftest import duckdb_args @@ -221,7 +221,7 @@ def test_clean(mock_path, tmp_path, args, expected): # pylint: disable=unused-a (["build", "-t", "vocab"], None, 3), ( [ # checking that a study is loaded from the directory of a user-defined path. 
- # we're also validating that the CLI accpes the statistics keyword, though + # we're also validating that the CLI accepts the statistics keyword, though "build", "-t", "study_valid", @@ -266,6 +266,50 @@ def test_cli_executes_queries(tmp_path, build_args, export_args, expected_tables assert any(export_table in x for x in csv_files) +@mock.patch.dict( + os.environ, + clear=True, +) +@pytest.mark.parametrize( + "study,finishes,raises", + [ + ("study_valid", True, does_not_raise()), + ( + "study_invalid_bad_query", + False, + pytest.raises(errors.StudyManifestQueryError), + ), + ], +) +def test_cli_transactions(tmp_path, study, finishes, raises): + with raises: + args = duckdb_args( + ["build", "-t", study, "--database", "test", "-s", "tests/test_data/"], + f"{tmp_path}", + ) + print(args[-1:]) + + args = args[:-1] + [ + f"{tmp_path}/{study}_duck.db", + ] + print(args[-1:]) + cli.main(cli_args=args) + db = DuckDatabaseBackend(f"{tmp_path}/{study}_duck.db") + print( + db.cursor() + .execute("select table_name from information_schema.tables") + .fetchall() + ) + query = ( + db.cursor().execute(f"SELECT * from study_valid__lib_transactions").fetchall() + ) + assert query[1][2] == "started" + if finishes: + assert query[2][2] == "finished" + else: + assert query[2][2] == "error" + + @mock.patch.dict( os.environ, clear=True, diff --git a/tests/test_data/study_invalid_bad_query/manifest.toml b/tests/test_data/study_invalid_bad_query/manifest.toml new file mode 100644 index 00000000..81ba3008 --- /dev/null +++ b/tests/test_data/study_invalid_bad_query/manifest.toml @@ -0,0 +1,7 @@ +study_prefix = "study_valid" + +[sql_config] +file_names = ["test.sql"] + +[export_config] +export_list = ["study_valid__table"] diff --git a/tests/test_data/study_invalid_bad_query/test.sql b/tests/test_data/study_invalid_bad_query/test.sql new file mode 100644 index 00000000..72477fc7 --- /dev/null +++ b/tests/test_data/study_invalid_bad_query/test.sql @@ -0,0 +1 @@ +this is not even close to being a valid sql query -- noqa From 28e8a801d776921bc3cae72b470739f26495f4d4 Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Tue, 26 Dec 2023 16:51:01 -0500 Subject: [PATCH 12/13] replaced inline PSM queries --- cumulus_library/statistics/psm.py | 15 ++-- .../template_sql/alias_table.sql.jinja | 2 + .../template_sql/select_all.sql.jinja | 1 + cumulus_library/template_sql/templates.py | 53 +++++++------- tests/test_templates.py | 73 ++++++++++--------- 5 files changed, 75 insertions(+), 69 deletions(-) create mode 100644 cumulus_library/template_sql/alias_table.sql.jinja create mode 100644 cumulus_library/template_sql/select_all.sql.jinja diff --git a/cumulus_library/statistics/psm.py b/cumulus_library/statistics/psm.py index b094352f..3475f9c8 100644 --- a/cumulus_library/statistics/psm.py +++ b/cumulus_library/statistics/psm.py @@ -20,10 +20,8 @@ from cumulus_library.databases import DatabaseCursor from cumulus_library.base_table_builder import BaseTableBuilder -from cumulus_library.template_sql.templates import ( - get_ctas_query_from_df, - get_drop_view_table, -) +from cumulus_library.template_sql import templates + from cumulus_library.template_sql.statistics.psm_templates import ( get_distinct_ids, get_create_covariate_table, @@ -167,7 +165,7 @@ def _create_covariate_table( ) cohort = pandas.concat([pos, neg]) - ctas_query = get_ctas_query_from_df( + ctas_query = templates.get_ctas_query_from_df( schema, f"{self.config.pos_source_table}_sampled_ids_{table_suffix}", cohort, @@ -268,10 +266,11 @@ def generate_psm_analysis( 
stats_table = f"{self.config.target_table}_{table_suffix}" """Runs PSM statistics on generated tables""" cursor.execute( - f"""CREATE OR REPLACE VIEW {self.config.target_table} - AS SELECT * FROM {stats_table}""" + templates.get_alias_table_query(stats_table, self.config.target_table) ) - df = cursor.execute(f"SELECT * FROM {self.config.target_table}").as_pandas() + df = cursor.execute( + templates.get_select_all_query(self.config.target_table) + ).as_pandas() symptoms_dict = self._get_symptoms_dict(self.config.classification_json) for dependent_variable, codes in symptoms_dict.items(): df[dependent_variable] = df["code"].apply(lambda x: 1 if x in codes else 0) diff --git a/cumulus_library/template_sql/alias_table.sql.jinja b/cumulus_library/template_sql/alias_table.sql.jinja new file mode 100644 index 00000000..65ff9e79 --- /dev/null +++ b/cumulus_library/template_sql/alias_table.sql.jinja @@ -0,0 +1,2 @@ +CREATE OR REPLACE VIEW {{ target }} + AS SELECT * FROM {{ source }}; \ No newline at end of file diff --git a/cumulus_library/template_sql/select_all.sql.jinja b/cumulus_library/template_sql/select_all.sql.jinja new file mode 100644 index 00000000..4f3e3e3f --- /dev/null +++ b/cumulus_library/template_sql/select_all.sql.jinja @@ -0,0 +1 @@ +SELECT * FROM {{ target }}; \ No newline at end of file diff --git a/cumulus_library/template_sql/templates.py b/cumulus_library/template_sql/templates.py index 463a91d0..13dbafc0 100644 --- a/cumulus_library/template_sql/templates.py +++ b/cumulus_library/template_sql/templates.py @@ -7,6 +7,9 @@ from pandas import DataFrame +PATH = Path(__file__).parent + + class TableView(Enum): """Convenience enum for building drop queries""" @@ -83,10 +86,15 @@ def __init__( self.is_array = is_array +def get_alias_table_query(source: str, target: str): + """Creates a 1-1 alias of a given table""" + with open(f"{PATH}/alias_table.sql.jinja") as alias_table: + return Template(alias_table.read()).render(source=source, target=target) + + def get_code_system_pairs(output_table_name: str, code_system_tables: list) -> str: """Extracts code system details as a standalone table""" - path = Path(__file__).parent - with open(f"{path}/code_system_pairs.sql.jinja") as code_system_pairs: + with open(f"{PATH}/code_system_pairs.sql.jinja") as code_system_pairs: return Template(code_system_pairs.read()).render( output_table_name=output_table_name, code_system_tables=code_system_tables ) @@ -104,14 +112,13 @@ def get_codeable_concept_denormalize_query(config: CodeableConceptConfig) -> str :param config: a CodableConeptConfig """ - path = Path(__file__).parent # If we get a None for code systems, we want one dummy value so the jinja # for loop will do a single pass. 
This implicitly means that we're not # filtering, so this parameter will be otherwise ignored config.code_systems = config.code_systems or ["all"] - with open(f"{path}/codeable_concept_denormalize.sql.jinja") as codable_concept: + with open(f"{PATH}/codeable_concept_denormalize.sql.jinja") as codable_concept: return Template(codable_concept.read()).render( source_table=config.source_table, source_id=config.source_id, @@ -124,8 +131,7 @@ def get_codeable_concept_denormalize_query(config: CodeableConceptConfig) -> str def get_column_datatype_query(schema_name: str, table_name: str, column_name: str): - path = Path(__file__).parent - with open(f"{path}/column_datatype.sql.jinja") as column_datatype: + with open(f"{PATH}/column_datatype.sql.jinja") as column_datatype: return Template(column_datatype.read()).render( schema_name=schema_name, table_name=table_name, @@ -136,8 +142,7 @@ def get_column_datatype_query(schema_name: str, table_name: str, column_name: st def get_core_medication_query( medication_datasources: dict, has_userselected: Optional[bool] = False ): - path = Path(__file__).parent - with open(f"{path}/core_medication.sql.jinja") as core_medication: + with open(f"{PATH}/core_medication.sql.jinja") as core_medication: return Template(core_medication.read()).render( medication_datasources=medication_datasources, has_userselected=has_userselected, @@ -153,8 +158,7 @@ def get_create_view_query( :param dataset: Array of data arrays to insert, i.e. [['1','3'],['2','4']] :param table_cols: Comma deleniated column names, i.e. ['first,second'] """ - path = Path(__file__).parent - with open(f"{path}/create_view_as.sql.jinja") as cvas: + with open(f"{PATH}/create_view_as.sql.jinja") as cvas: return Template(cvas.read()).render( view_name=view_name, dataset=dataset, @@ -177,8 +181,7 @@ def get_ctas_query( :param dataset: Array of data arrays to insert, i.e. [['1','3'],['2','4']] :param table_cols: Comma deleniated column names, i.e. ['first,second'] """ - path = Path(__file__).parent - with open(f"{path}/ctas.sql.jinja") as ctas: + with open(f"{PATH}/ctas.sql.jinja") as ctas: return Template(ctas.read()).render( schema_name=schema_name, table_name=table_name, @@ -220,10 +223,9 @@ def get_ctas_empty_query( :param table_cols: Comma deleniated column names, i.e. ['first,second'] :param table_cols_types: Allows specifying a data type per column (default: all varchar) """ - path = Path(__file__).parent if not table_cols_types: table_cols_types = ["varchar"] * len(table_cols) - with open(f"{path}/ctas_empty.sql.jinja") as ctas_empty: + with open(f"{PATH}/ctas_empty.sql.jinja") as ctas_empty: return Template(ctas_empty.read()).render( schema_name=schema_name, table_name=table_name, @@ -235,8 +237,7 @@ def get_ctas_empty_query( def get_drop_view_table(name: str, view_or_table: str) -> str: """Generates a drop table if exists query""" if view_or_table in [e.value for e in TableView]: - path = Path(__file__).parent - with open(f"{path}/drop_view_table.sql.jinja") as drop_view_table: + with open(f"{PATH}/drop_view_table.sql.jinja") as drop_view_table: return Template(drop_view_table.read()).render( view_or_table_name=name, view_or_table=view_or_table ) @@ -255,8 +256,7 @@ def get_extension_denormalize_query(config: ExtensionConfig) -> str: :param config: An instance of ExtensionConfig. 
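A quick check of the two new templates (table names are hypothetical); the rendered SQL matches the jinja files and the template tests later in this commit:

```python
from cumulus_library.template_sql import templates

print(templates.get_alias_table_query("study__psm_cohort_2023_12_26", "study__psm_cohort"))
# CREATE OR REPLACE VIEW study__psm_cohort
#     AS SELECT * FROM study__psm_cohort_2023_12_26;

print(templates.get_select_all_query("study__psm_cohort"))
# SELECT * FROM study__psm_cohort;
```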
""" - path = Path(__file__).parent - with open(f"{path}/extension_denormalize.sql.jinja") as extension_denormalize: + with open(f"{PATH}/extension_denormalize.sql.jinja") as extension_denormalize: return Template(extension_denormalize.read()).render( source_table=config.source_table, source_id=config.source_id, @@ -281,8 +281,7 @@ def get_insert_into_query( :param table_cols: Comma deleniated column names, i.e. ['first','second'] :param dataset: Array of data arrays to insert, i.e. [['1','3'],['2','4']] """ - path = Path(__file__).parent - with open(f"{path}/insert_into.sql.jinja") as insert_into: + with open(f"{PATH}/insert_into.sql.jinja") as insert_into: return Template(insert_into.read()).render( table_name=table_name, table_cols=table_cols, @@ -297,8 +296,7 @@ def get_is_table_not_empty_query( unnests: Optional[list[dict]] = [], conditions: Optional[list[str]] = [], ): - path = Path(__file__).parent - with open(f"{path}/is_table_not_empty.sql.jinja") as is_table_not_empty: + with open(f"{PATH}/is_table_not_empty.sql.jinja") as is_table_not_empty: return Template(is_table_not_empty.read()).render( source_table=source_table, field=field, @@ -307,6 +305,11 @@ def get_is_table_not_empty_query( ) +def get_select_all_query(target: str): + with open(f"{PATH}/select_all.sql.jinja") as select_all: + return Template(select_all.read()).render(target=target) + + def get_show_tables(schema_name: str, prefix: str) -> str: """Generates a show tables query, filtered by prefix @@ -316,8 +319,7 @@ def get_show_tables(schema_name: str, prefix: str) -> str: :param schema_name: The athena schema to query :param table_name: The prefix to filter by. Jinja template auto adds '__'. """ - path = Path(__file__).parent - with open(f"{path}/show_tables.sql.jinja") as show_tables: + with open(f"{PATH}/show_tables.sql.jinja") as show_tables: return Template(show_tables.read()).render( schema_name=schema_name, prefix=prefix ) @@ -332,8 +334,7 @@ def get_show_views(schema_name: str, prefix: str) -> str: :param schema_name: The athena schema to query :param table_name: The prefix to filter by. Jinja template auto adds '__'. 
""" - path = Path(__file__).parent - with open(f"{path}/show_views.sql.jinja") as show_tables: + with open(f"{PATH}/show_views.sql.jinja") as show_tables: return Template(show_tables.read()).render( schema_name=schema_name, prefix=prefix ) diff --git a/tests/test_templates.py b/tests/test_templates.py index e154f124..38bbb255 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -3,21 +3,20 @@ from pandas import DataFrame -from cumulus_library.template_sql.templates import ( - CodeableConceptConfig, - ExtensionConfig, - get_code_system_pairs, - get_codeable_concept_denormalize_query, - get_column_datatype_query, - get_core_medication_query, - get_create_view_query, - get_ctas_empty_query, - get_ctas_query, - get_ctas_query_from_df, - get_extension_denormalize_query, - get_insert_into_query, - get_is_table_not_empty_query, -) +from cumulus_library.template_sql import templates + + +def test_alias_table(): + expected = """CREATE OR REPLACE VIEW target + AS SELECT * FROM source;""" + query = templates.get_alias_table_query("source", "target") + assert query == expected + + +def test_select_all(): + expected = """SELECT * FROM source;""" + query = templates.get_select_all_query("source") + assert query == expected def test_codeable_concept_denormalize_all_creation(): @@ -52,14 +51,14 @@ def test_codeable_concept_denormalize_all_creation(): FROM union_table ); """ - config = CodeableConceptConfig( + config = templates.CodeableConceptConfig( source_table="source", source_id="id", column_name="code_col", target_table="target__concepts", is_array=True, ) - query = get_codeable_concept_denormalize_query(config) + query = templates.get_codeable_concept_denormalize_query(config) assert query == expected @@ -140,7 +139,7 @@ def test_codeable_concept_denormalize_filter_creation(): ); """ - config = CodeableConceptConfig( + config = templates.CodeableConceptConfig( source_table="source", source_id="id", column_name="code_col", @@ -152,7 +151,7 @@ def test_codeable_concept_denormalize_filter_creation(): "http://hl7.org/fhir/sid/icd-10-cm", ], ) - query = get_codeable_concept_denormalize_query(config) + query = templates.get_codeable_concept_denormalize_query(config) assert query == expected @@ -165,7 +164,7 @@ def test_get_column_datatype_query(): AND table_name = 'table_name' AND LOWER(column_name) = 'column_name'""" - query = get_column_datatype_query( + query = templates.get_column_datatype_query( schema_name="schema_name", table_name="table_name", column_name="column_name", @@ -215,7 +214,9 @@ def test_get_column_datatype_query(): ], ) def test_core_medication_query(medication_datasources, contains, omits): - query = get_core_medication_query(medication_datasources=medication_datasources) + query = templates.get_core_medication_query( + medication_datasources=medication_datasources + ) for item in contains: assert item in query for item in omits: @@ -232,7 +233,7 @@ def test_create_view_query_creation(): AS t ("a","b") );""" - query = get_create_view_query( + query = templates.get_create_view_query( view_name="test_view", dataset=[["foo", "foo"], ["bar", "bar"]], view_cols=["a", "b"], @@ -274,7 +275,7 @@ def test_create_view_query_creation(): ], ) def test_ctas_empty_query_creation(expected, schema, table, cols, types): - query = get_ctas_empty_query( + query = templates.get_ctas_empty_query( schema_name=schema, table_name=table, table_cols=cols, table_cols_types=types ) assert query == expected @@ -289,14 +290,14 @@ def test_ctas_query_creation(): ) AS t ("a","b") );""" - query = 
get_ctas_query( + query = templates.get_ctas_query( schema_name="test_schema", table_name="test_table", dataset=[["foo", "foo"], ["bar", "bar"]], table_cols=["a", "b"], ) assert query == expected - query = get_ctas_query_from_df( + query = templates.get_ctas_query_from_df( schema_name="test_schema", table_name="test_table", df=DataFrame({"a": ["foo", "bar"], "b": ["foo", "bar"]}), @@ -384,7 +385,7 @@ def test_extension_denormalize_creation(): ) WHERE available_priority = 1 );""" - config = ExtensionConfig( + config = templates.ExtensionConfig( "source_table", "source_id", "target_table", @@ -392,9 +393,9 @@ def test_extension_denormalize_creation(): "fhir_extension", ["omb", "text"], ) - query = get_extension_denormalize_query(config) + query = templates.get_extension_denormalize_query(config) assert query == expected - config = ExtensionConfig( + config = templates.ExtensionConfig( "source_table", "source_id", "target_table", @@ -403,7 +404,7 @@ def test_extension_denormalize_creation(): ["omb", "text"], is_array=True, ) - query = get_extension_denormalize_query(config) + query = templates.get_extension_denormalize_query(config) array_sql = """LOWER( ARRAY_JOIN( ARRAY_SORT( @@ -432,7 +433,7 @@ def test_insert_into_query_creation(): VALUES ('foo','foo'), ('bar','bar');""" - query = get_insert_into_query( + query = templates.get_insert_into_query( table_name="test_table", table_cols=["a", "b"], dataset=[["foo", "foo"], ["bar", "bar"]], @@ -443,7 +444,7 @@ def test_insert_into_query_creation(): VALUES ('foo',VARCHAR 'foo'), ('bar',VARCHAR 'bar');""" - query = get_insert_into_query( + query = templates.get_insert_into_query( table_name="test_table", table_cols=["a", "b"], dataset=[["foo", "foo"], ["bar", "bar"]], @@ -460,7 +461,9 @@ def test_is_table_not_empty(): WHERE field_name IS NOT NULL LIMIT 1;""" - query = get_is_table_not_empty_query(source_table="table_name", field="field_name") + query = templates.get_is_table_not_empty_query( + source_table="table_name", field="field_name" + ) assert query == expected expected = """SELECT @@ -472,7 +475,7 @@ def test_is_table_not_empty(): WHERE field_name IS NOT NULL LIMIT 1;""" - query = get_is_table_not_empty_query( + query = templates.get_is_table_not_empty_query( source_table="table_name", field="field_name", unnests=[ @@ -492,7 +495,7 @@ def test_is_table_not_empty(): AND field_name IS NOT NULL --noqa: LT02 LIMIT 1;""" - query = get_is_table_not_empty_query( + query = templates.get_is_table_not_empty_query( source_table="table_name", field="field_name", conditions=["field_name LIKE 's%'", "field_name IS NOT NULL"], @@ -542,7 +545,7 @@ def test_get_code_system_pairs(): ) ) AS t (table_name, column_name, code, display, system)""" - query = get_code_system_pairs( + query = templates.get_code_system_pairs( "output_table", [ { From d6345d264903f756c677ef657e16e74ba8b6341b Mon Sep 17 00:00:00 2001 From: Matt Garber Date: Tue, 26 Dec 2023 16:54:24 -0500 Subject: [PATCH 13/13] sqlfluff cleanup --- cumulus_library/template_sql/alias_table.sql.jinja | 4 ++-- cumulus_library/template_sql/select_all.sql.jinja | 2 +- cumulus_library/template_sql/templates.py | 10 ++++++---- tests/test_templates.py | 4 ++-- 4 files changed, 11 insertions(+), 9 deletions(-) diff --git a/cumulus_library/template_sql/alias_table.sql.jinja b/cumulus_library/template_sql/alias_table.sql.jinja index 65ff9e79..d121917f 100644 --- a/cumulus_library/template_sql/alias_table.sql.jinja +++ b/cumulus_library/template_sql/alias_table.sql.jinja @@ -1,2 +1,2 @@ -CREATE OR REPLACE VIEW {{ 
target }} - AS SELECT * FROM {{ source }}; \ No newline at end of file +CREATE OR REPLACE VIEW {{ target_table }} +AS SELECT * FROM {{ source_table }}; diff --git a/cumulus_library/template_sql/select_all.sql.jinja b/cumulus_library/template_sql/select_all.sql.jinja index 4f3e3e3f..9faef1c4 100644 --- a/cumulus_library/template_sql/select_all.sql.jinja +++ b/cumulus_library/template_sql/select_all.sql.jinja @@ -1 +1 @@ -SELECT * FROM {{ target }}; \ No newline at end of file +SELECT * FROM {{ source_table }}; diff --git a/cumulus_library/template_sql/templates.py b/cumulus_library/template_sql/templates.py index 13dbafc0..1bd013e5 100644 --- a/cumulus_library/template_sql/templates.py +++ b/cumulus_library/template_sql/templates.py @@ -86,10 +86,12 @@ def __init__( self.is_array = is_array -def get_alias_table_query(source: str, target: str): +def get_alias_table_query(source_table: str, target_table: str): """Creates a 1-1 alias of a given table""" with open(f"{PATH}/alias_table.sql.jinja") as alias_table: - return Template(alias_table.read()).render(source=source, target=target) + return Template(alias_table.read()).render( + source_table=source_table, target_table=target_table + ) def get_code_system_pairs(output_table_name: str, code_system_tables: list) -> str: @@ -305,9 +307,9 @@ def get_is_table_not_empty_query( ) -def get_select_all_query(target: str): +def get_select_all_query(source_table: str): with open(f"{PATH}/select_all.sql.jinja") as select_all: - return Template(select_all.read()).render(target=target) + return Template(select_all.read()).render(source_table=source_table) def get_show_tables(schema_name: str, prefix: str) -> str: diff --git a/tests/test_templates.py b/tests/test_templates.py index 38bbb255..01903c48 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -7,8 +7,8 @@ def test_alias_table(): - expected = """CREATE OR REPLACE VIEW target - AS SELECT * FROM source;""" + expected = """CREATE OR REPLACE VIEW target +AS SELECT * FROM source;""" query = templates.get_alias_table_query("source", "target") assert query == expected
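
For quick reference, below is a minimal, self-contained sketch of how the two helpers renamed in PATCH 13/13 behave after this series. It is illustrative only, not the library's implementation: it inlines the template bodies from alias_table.sql.jinja and select_all.sql.jinja shown above rather than reading them from disk via the module-level PATH constant, and it assumes only that the jinja2 package is installed.

from jinja2 import Template

# Template bodies copied from the post-patch .sql.jinja files above.
ALIAS_TABLE = (
    "CREATE OR REPLACE VIEW {{ target_table }}\n"
    "AS SELECT * FROM {{ source_table }};"
)
SELECT_ALL = "SELECT * FROM {{ source_table }};"


def get_alias_table_query(source_table: str, target_table: str) -> str:
    """Creates a 1-1 view alias of a given table (mirrors templates.py)."""
    return Template(ALIAS_TABLE).render(
        source_table=source_table, target_table=target_table
    )


def get_select_all_query(source_table: str) -> str:
    """Selects every row from a table (mirrors templates.py)."""
    return Template(SELECT_ALL).render(source_table=source_table)


# These assertions match the expected strings in tests/test_templates.py:
assert get_alias_table_query("source", "target") == (
    "CREATE OR REPLACE VIEW target\nAS SELECT * FROM source;"
)
assert get_select_all_query("source") == "SELECT * FROM source;"

Note the design choice the rename captures: both templates now take source_table/target_table keywords, matching the source_table parameter already used by get_is_table_not_empty_query and the denormalization templates, so callers pass consistently named arguments across the template module.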