Skip to content

Commit

Permalink
PR feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
dogversioning committed Dec 26, 2023
1 parent 78eddb9 commit 92d4c96
Show file tree
Hide file tree
Showing 11 changed files with 179 additions and 124 deletions.
26 changes: 15 additions & 11 deletions cumulus_library/cli.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
#!/usr/bin/env python3
"""Utility for building/retrieving data views in AWS Athena"""

import datetime
import json
import os
import sys
import sysconfig

from datetime import datetime
from pathlib import Path, PosixPath
from typing import Dict, List, Optional

from rich.console import Console
from rich.table import Table

from cumulus_library import __version__
from cumulus_library import __version__, errors, helper
from cumulus_library.cli_parser import get_parser
from cumulus_library.databases import (
DatabaseBackend,
Expand Down Expand Up @@ -49,7 +49,7 @@ def update_transactions(self, prefix: str, status: str):
prefix,
__version__,
status,
datetime.now().replace(microsecond=0).isoformat(),
helper.get_utc_date(),
]
],
)
Expand All @@ -61,6 +61,7 @@ def clean_study(
self,
targets: List[str],
study_dict: Dict,
*,
stats_clean: bool,
prefix: bool = False,
) -> None:
Expand Down Expand Up @@ -102,6 +103,7 @@ def clean_study(
def clean_and_build_study(
self,
target: PosixPath,
*,
stats_build: bool,
continue_from: str = None,
) -> None:
Expand Down Expand Up @@ -145,7 +147,7 @@ def clean_and_build_study(
)
self.update_transactions(studyparser.get_study_prefix(), "finished")

except SystemExit as e:
except errors.StudyManifestFilesystemError as e:
# This should be thrown prior to any database connections, so
# skipping logging
raise e
Expand Down Expand Up @@ -178,10 +180,12 @@ def clean_and_build_all(self, study_dict: Dict, stats_build: bool) -> None:
study_dict = dict(study_dict)
study_dict.pop("template")
for precursor_study in ["vocab", "core"]:
self.clean_and_build_study(study_dict[precursor_study], stats_build)
self.clean_and_build_study(
study_dict[precursor_study], stats_build=stats_build
)
study_dict.pop(precursor_study)
for key in study_dict:
self.clean_and_build_study(study_dict[key], stats_build)
self.clean_and_build_study(study_dict[key], stats_build=stats_build)

### Data exporters
def export_study(self, target: PosixPath, data_path: PosixPath) -> None:
Expand Down Expand Up @@ -275,10 +279,10 @@ def run_cli(args: Dict):
elif args["action"] == "upload":
upload_files(args)

# all other actions require connecting to AWS
# all other actions require connecting to the database
else:
db_backend = create_db_backend(args)
builder = StudyBuilder(db_backend, data_path=args.get("data_path", None))
builder = StudyBuilder(db_backend, data_path=args.get("data_path"))
if args["verbose"]:
builder.verbose = True
print("Testing connection to database...")
Expand All @@ -299,8 +303,8 @@ def run_cli(args: Dict):
builder.clean_study(
args["target"],
study_dict,
args["stats_clean"],
args["prefix"],
stats_clean=args["stats_clean"],
prefix=args["prefix"],
)
elif args["action"] == "build":
if "all" in args["target"]:
Expand All @@ -314,7 +318,7 @@ def run_cli(args: Dict):
else:
builder.clean_and_build_study(
study_dict[target],
args["stats_build"],
stats_build=args["stats_build"],
continue_from=args["continue_from"],
)

Expand Down
6 changes: 2 additions & 4 deletions cumulus_library/databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import datetime
import json
import os
import warnings
import sys
from pathlib import Path
from typing import Optional, Protocol, Union

Expand Down Expand Up @@ -265,9 +265,7 @@ def create_db_backend(args: dict[str, str]) -> DatabaseBackend:
database,
)
if load_ndjson_dir:
warnings.warn(
"Loading an ndjson dir is not supported with --db-type=athena."
)
sys.exit("Loading an ndjson dir is not supported with --db-type=athena.")
else:
raise ValueError(f"Unexpected --db-type value '{db_type}'")

Expand Down
18 changes: 10 additions & 8 deletions cumulus_library/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,21 @@

class CumulusLibraryError(Exception):
    """Base error type for the cumulus_library package."""


class CumulusLibrarySchemaError(Exception):
    """
    Generic package level error
    """


class CountsBuilderError(Exception):
    """Raised for errors encountered while running a CountsBuilder."""


class StudyManifestFilesystemError(Exception):
    """Raised by StudyManifestParser for problems with files on disk."""


class StudyManifestParsingError(Exception):
    """Errors related to manifest parsing in StudyManifestParser"""


class StudyManifestQueryError(Exception):
    """Raised by StudyManifestParser when a data query fails."""
14 changes: 14 additions & 0 deletions cumulus_library/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
""" Collection of small commonly used utility functions """
import datetime
import os
import json
from typing import List
Expand Down Expand Up @@ -66,3 +67,16 @@ def get_progress_bar(**kwargs) -> progress.Progress:
progress.TimeRemainingColumn(elapsed_when_finished=True),
**kwargs,
)


def get_utc_date() -> str:
    """Returns the current UTC time as a second-precision ISO-8601 string.

    Note: despite the name, this produces a full timestamp string (date and
    time with a +00:00 offset), not a datetime object — the original return
    annotation of ``datetime.datetime`` was wrong, since ``isoformat()``
    yields a ``str``. Microseconds are stripped so stored event times stay
    second-granular.
    """
    return (
        datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0).isoformat()
    )


def get_tablename_safe_iso_timestamp() -> str:
    """Returns the current UTC ISO timestamp, reformatted for table names.

    SQL table names cannot contain ':', '-', or '+', so each of those
    characters is swapped for an underscore in a single translate pass.
    """
    sanitize = str.maketrans(":-+", "___")
    return get_utc_date().translate(sanitize)
25 changes: 13 additions & 12 deletions cumulus_library/protected_table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
)

TRANSACTIONS_COLS = ["study_name", "library_version", "status", "event_time"]
TRANSACTION_COLS_TYPES = ["varchar", "varchar", "varchar", "timestamp"]
# while it may seem redundant, study_name and view_name are included as
# columns for ease of constructing a view of multiple transaction tables
STATISTICS_COLS = [
"study_name",
"library_version",
Expand All @@ -14,6 +17,14 @@
"view_name",
"created_on",
]
STATISTICS_COLS_TYPES = [
"varchar",
"varchar",
"varchar",
"varchar",
"varchar",
"timestamp",
]


class ProtectedTableBuilder(BaseTableBuilder):
Expand All @@ -28,26 +39,16 @@ def prepare_queries(
get_ctas_empty_query(
schema,
f"{study_name}__{ProtectedTables.TRANSACTIONS.value}",
# while it may seem redundant, study name is included for ease
# of constructing a view of multiple transaction tables
TRANSACTIONS_COLS,
["varchar", "varchar", "varchar", "timestamp"],
TRANSACTION_COLS_TYPES,
)
)
if study_stats:
self.queries.append(
get_ctas_empty_query(
schema,
f"{study_name}__{ProtectedTables.STATISTICS.value}",
# same redundancy note about study_name, and also view_name, applies here
STATISTICS_COLS,
[
"varchar",
"varchar",
"varchar",
"varchar",
"varchar",
"timestamp",
],
STATISTICS_COLS_TYPES,
)
)
17 changes: 10 additions & 7 deletions cumulus_library/statistics/psm.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ class PsmConfig:
These values should be read in from a toml configuration file.
See docs/statistics/propensity-score-matching.md for an example with details about
the expected values for these fields.
A word of caution about sampling: the assumptions around PSM analysis require
that any sampling should not use replacement, so do not enable the `replace`
option of pandas' DataFrame.sample(). This means that very small population
sizes (i.e. < 20ish) may cause errors to be generated.
"""

classification_json: str
Expand Down Expand Up @@ -122,7 +127,7 @@ def _get_sampled_ids(
df = cursor.execute(query).as_pandas()
df = df.sort_values(by=[self.config.primary_ref])
df = (
df.sample(n=sample_size, random_state=self.config.seed, replace=False)
df.sample(n=sample_size, random_state=self.config.seed)
.sort_values(by=[self.config.primary_ref])
.reset_index()
.drop("index", axis=1)
Expand Down Expand Up @@ -181,7 +186,6 @@ def _create_covariate_table(
count_table=self.config.count_table,
)
self.queries.append(dataset_query)
print(dataset_query)

def psm_plot_match(
self,
Expand Down Expand Up @@ -261,10 +265,11 @@ def psm_effect_size_plot(
def generate_psm_analysis(
self, cursor: DatabaseCursor, schema: str, table_suffix: str
):
stats_table = f"{self.config.target_table}_{table_suffix}"
"""Runs PSM statistics on generated tables"""
cursor.execute(
f"""CREATE OR REPLACE VIEW {self.config.target_table}
AS SELECT * FROM {self.config.target_table}_{table_suffix}"""
AS SELECT * FROM {stats_table}"""
)
df = cursor.execute(f"SELECT * FROM {self.config.target_table}").as_pandas()
symptoms_dict = self._get_symptoms_dict(self.config.classification_json)
Expand Down Expand Up @@ -327,14 +332,12 @@ def generate_psm_analysis(
self.psm_plot_match(
psm,
save=True,
filename=self.data_path
/ f"{self.config.target_table}_{table_suffix}_propensity_match.png",
filename=self.data_path / f"{stats_table}_propensity_match.png",
)
self.psm_effect_size_plot(
psm,
save=True,
filename=self.data_path
/ f"{self.config.target_table}_{table_suffix}_effect_size.png",
filename=self.data_path / f"{stats_table}_effect_size.png",
)
except ZeroDivisionError:
sys.exit(
Expand Down
Loading

0 comments on commit 92d4c96

Please sign in to comment.