Skip to content

Commit

Permalink
existing unit test rework, data generator
Browse files Browse the repository at this point in the history
  • Loading branch information
dogversioning committed Dec 20, 2023
1 parent 8ee5515 commit ff92be6
Show file tree
Hide file tree
Showing 21 changed files with 412 additions and 211 deletions.
2 changes: 2 additions & 0 deletions cumulus_library/.sqlfluff
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ schema_name = test_schema
source_table = source_table
source_id = source_id
table_cols = ["a","b"]
table_cols_types = ["varchar", "varchar"]
table_name = test_table
table_suffix = 2024_01_01_11_11_11
target_col_prefix = prefix
target_table = target_table
unnests = [{"source col": "g", "table_alias": "i", "row_alias":"j"}, {"source col": "k", "table_alias": "l", "row_alias":"m"},]
Expand Down
2 changes: 0 additions & 2 deletions cumulus_library/base_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@ def prepare_queries(self, cursor: object, schema: str, *args, **kwargs):
:param cursor: A PEP-249 compatible cursor
:param schema: A schema name
:param db_type: The db system being used (only relevant for db-specific
query construction)
"""
raise NotImplementedError

Expand Down
24 changes: 7 additions & 17 deletions cumulus_library/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,13 @@ def clean_study(
def clean_and_build_study(
self,
target: PosixPath,
export_dir: PosixPath,
stats_build: bool,
continue_from: str = None,
) -> None:
"""Recreates study views/tables
:param target: A PosixPath to the study directory
"""

studyparser = StudyManifestParser(target, self.data_path)
try:
if not continue_from:
Expand Down Expand Up @@ -140,6 +138,8 @@ def clean_and_build_study(
data_path=self.data_path,
)
self.update_transactions(studyparser.get_study_prefix(), "finished")
except SystemExit as exit:
raise exit
except Exception as e:
self.update_transactions(studyparser.get_study_prefix(), "error")
raise e
Expand All @@ -156,7 +156,7 @@ def run_single_table_builder(
self.cursor, self.schema_name, table_builder_name, self.verbose
)

def clean_and_build_all(self, study_dict: Dict, export_dir: PosixPath) -> None:
def clean_and_build_all(self, study_dict: Dict, stats_build: bool) -> None:
"""Builds views for all studies.
NOTE: By design, this method will always exclude the `template` study dir,
Expand All @@ -167,10 +167,10 @@ def clean_and_build_all(self, study_dict: Dict, export_dir: PosixPath) -> None:
study_dict = dict(study_dict)
study_dict.pop("template")
for precursor_study in ["vocab", "core"]:
self.clean_and_build_study(study_dict[precursor_study], export_dir)
self.clean_and_build_study(study_dict[precursor_study], stats_build)
study_dict.pop(precursor_study)
for key in study_dict:
self.clean_and_build_study(study_dict[key])
self.clean_and_build_study(study_dict[key], stats_build)

### Data exporters
def export_study(self, target: PosixPath, data_path: PosixPath) -> None:
Expand Down Expand Up @@ -267,7 +267,7 @@ def run_cli(args: Dict):
# all other actions require connecting to AWS
else:
db_backend = create_db_backend(args)
builder = StudyBuilder(db_backend, args["data_path"])
builder = StudyBuilder(db_backend, data_path=args.get("data_path", None))
if args["verbose"]:
builder.verbose = True
print("Testing connection to database...")
Expand All @@ -293,9 +293,7 @@ def run_cli(args: Dict):
)
elif args["action"] == "build":
if "all" in args["target"]:
builder.clean_and_build_all(
study_dict, args["export_dir"], args["stats_build"]
)
builder.clean_and_build_all(study_dict, args["stats_build"])
else:
for target in args["target"]:
if args["builder"]:
Expand All @@ -305,7 +303,6 @@ def run_cli(args: Dict):
else:
builder.clean_and_build_study(
study_dict[target],
args["data_path"],
args["stats_build"],
continue_from=args["continue_from"],
)
Expand All @@ -317,13 +314,6 @@ def run_cli(args: Dict):
for target in args["target"]:
builder.export_study(study_dict[target], args["data_path"])

# print(set(builder.cursor.execute("""SELECT table_name
# FROM information_schema.tables
# where table_name ilike '%_lib_%'
# or table_name ilike '%_psm_%'""").fetchall()))
# print(builder.cursor.execute("select * from psm_test__lib_statistics").fetchall())
# print(builder.cursor.execute("select * from psm_test__lib_transactions").fetchall())
# print(builder.cursor.execute("select * from psm_test__psm_encounter_covariate").fetchall())
db_backend.close()
# returning the builder for ease of unit testing
return builder
Expand Down
3 changes: 1 addition & 2 deletions cumulus_library/cli_parser.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""Manages configuration for argparse"""
import argparse

from cumulus_library.errors import CumulusLibraryError


def add_target_argument(parser: argparse.ArgumentParser) -> None:
"""Adds --target arg to a subparser"""
Expand Down Expand Up @@ -171,6 +169,7 @@ def get_parser() -> argparse.ArgumentParser:
add_table_builder_argument(build)
add_study_dir_argument(build)
add_verbose_argument(build)
add_data_path_argument(build)
add_db_config(build)
build.add_argument(
"--statistics",
Expand Down
9 changes: 5 additions & 4 deletions cumulus_library/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import json
import os
import sys
import warnings
from pathlib import Path
from typing import Optional, Protocol, Union

Expand Down Expand Up @@ -264,10 +265,10 @@ def create_db_backend(args: dict[str, str]) -> DatabaseBackend:
args["profile"],
database,
)
# if load_ndjson_dir:
# sys.exit(
# "Loading an ndjson dir is not supported with --db-type=athena (try duckdb)"
# )
if load_ndjson_dir:
warnings.warn(
"Loading an ndjson dir is not supported with --db-type=athena."
)
else:
raise ValueError(f"Unexpected --db-type value '{db_type}'")

Expand Down
43 changes: 20 additions & 23 deletions cumulus_library/protected_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,10 @@
class ProtectedTableBuilder(BaseTableBuilder):
display_text = "Creating/updating system tables..."

def prepare_queries(self, cursor: object, schema: str, study_name: str):
safe_timestamp = (
datetime.datetime.now()
.replace(microsecond=0)
.isoformat()
.replace(":", "_")
.replace("-", "_")
)
def prepare_queries(
self, cursor: object, schema: str, study_name: str, study_stats: dict
):
print("hi")
self.queries.append(
get_ctas_empty_query(
schema,
Expand All @@ -40,19 +36,20 @@ def prepare_queries(self, cursor: object, schema: str, study_name: str):
["varchar", "varchar", "varchar", "timestamp"],
)
)
self.queries.append(
get_ctas_empty_query(
schema,
f"{study_name}__{PROTECTED_TABLES.STATISTICS.value}",
# same redundancy note about study_name, and also view_name, applies here
STATISTICS_COLS,
[
"varchar",
"varchar",
"varchar",
"varchar",
"varchar",
"timestamp",
],
if study_stats:
self.queries.append(
get_ctas_empty_query(
schema,
f"{study_name}__{PROTECTED_TABLES.STATISTICS.value}",
# same redundancy note about study_name, and also view_name, applies here
STATISTICS_COLS,
[
"varchar",
"varchar",
"varchar",
"varchar",
"varchar",
"timestamp",
],
)
)
)
75 changes: 36 additions & 39 deletions cumulus_library/statistics/psm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import os
import sys
import warnings

from pathlib import PosixPath
from dataclasses import dataclass
Expand Down Expand Up @@ -57,12 +58,12 @@ class PsmBuilder(BaseTableBuilder):

display_text = "Building PSM tables..."

def __init__(self, toml_config_path: str, export_path: PosixPath):
def __init__(self, toml_config_path: str, data_path: PosixPath):
"""Loads PSM job details from a PSM configuration file"""
super().__init__()
# We're stashing the toml path for error reporting later
self.toml_path = toml_config_path
self.export_path = export_path
self.data_path = data_path
try:
with open(self.toml_path, encoding="UTF-8") as file:
toml_config = toml.load(file)
Expand Down Expand Up @@ -166,23 +167,15 @@ def _create_covariate_table(
self.config.dependent_variable,
0,
)
cohort = pandas.concat([pos, neg])

# Replace table (if it exists)
# TODO - replace with timestamp prepended table in future PR
# drop = get_drop_view_table(
# f"{self.config.pos_source_table}_sampled_ids", "TABLE"
# )
# cursor.execute(drop)
cohort = pandas.concat([pos, neg])
ctas_query = get_ctas_query_from_df(
schema,
f"{self.config.pos_source_table}_sampled_ids_{table_suffix}",
cohort,
)
self.queries.append(ctas_query)
# TODO - replace with timestamp prepended table
# drop = get_drop_view_table(self.config.target_table, "TABLE")
# cursor.execute(drop)

dataset_query = get_create_covariate_table(
target_table=f"{self.config.target_table}_{table_suffix}",
pos_source_table=self.config.pos_source_table,
Expand All @@ -195,6 +188,7 @@ def _create_covariate_table(
count_table=self.config.count_table,
)
self.queries.append(dataset_query)
print(dataset_query)

def psm_plot_match(
self,
Expand Down Expand Up @@ -315,40 +309,43 @@ def generate_psm_analysis(
df = pandas.concat([df, encoded_df], axis=1)
df = df.drop(column, axis=1)
df = df.reset_index()

try:
psm = PsmPy(
df,
treatment=self.config.dependent_variable,
indx=self.config.primary_ref,
exclude=[],
)
# This function populates the psm.predicted_data element, which is required
# for things like the knn_matched() function call
# TODO: create graph from this data
psm.logistic_ps(balance=True)
# This function populates the psm.df_matched element
# TODO: flip replacement to false after increasing sample data size
# TODO: create graph from this data
psm.knn_matched(
matcher="propensity_logit",
replacement=True,
caliper=None,
drop_unmatched=True,
)
os.makedirs(self.export_path, exist_ok=True)
self.psm_plot_match(
psm,
save=True,
filename=self.export_path
/ f"{self.config.target_table}_{table_suffix}_propensity_match.png",
)
self.psm_effect_size_plot(
psm,
save=True,
filename=self.export_path
/ f"{self.config.target_table}_{table_suffix}_effect_size.png",
)

# we expect psmpy to autodrop non-matching values, so we'll suppress its
# warnings mentioning workarounds for this behavior.
with warnings.catch_warnings():
warnings.simplefilter("ignore", category=UserWarning)
# This function populates the psm.predicted_data element, which is required
# for things like the knn_matched() function call
# TODO: create graph from this data
psm.logistic_ps(balance=True)
# This function populates the psm.df_matched element
# TODO: create graph from this data
psm.knn_matched(
matcher="propensity_logit",
replacement=False,
caliper=None,
drop_unmatched=True,
)
os.makedirs(self.data_path, exist_ok=True)
self.psm_plot_match(
psm,
save=True,
filename=self.data_path
/ f"{self.config.target_table}_{table_suffix}_propensity_match.png",
)
self.psm_effect_size_plot(
psm,
save=True,
filename=self.data_path
/ f"{self.config.target_table}_{table_suffix}_effect_size.png",
)
except ZeroDivisionError:
sys.exit(
"Encountered a divide by zero error during statistical graph generation. Try increasing your sample size."
Expand Down
Loading

0 comments on commit ff92be6

Please sign in to comment.