PSM table generation
dogversioning committed Nov 30, 2023
1 parent f2aafcd commit 87d067a
Showing 11 changed files with 540 additions and 4 deletions.
2 changes: 1 addition & 1 deletion cumulus_library/.sqlfluff
@@ -4,7 +4,7 @@ dialect = athena
 sql_file_exts = .sql,.sql.jinja
 # this rule overfires on athena nested arrays
 exclude_rules=references.from,structure.column_order,aliasing.unused
-max_line_length = 88
+max_line_length = 90
 
 [sqlfluff:indentation]
 template_blocks_indent = false
7 changes: 6 additions & 1 deletion cumulus_library/base_table_builder.py
@@ -33,7 +33,12 @@ def prepare_queries(self, cursor: object, schema: str):
         """
         raise NotImplementedError
 
-    @final
+    # 🚨🚨🚨 WARNING: 🚨🚨🚨 in 99% of cases, subclasses should *not* re-implement
+    # execute_queries.
+
+    # If you know what you are doing, you can attempt to override it, but it is
+    # strongly recommended you invoke this as is via a super() call, and then
+    # run code before or after that as makes sense for your use case.
     def execute_queries(
         self,
         cursor: DatabaseCursor,
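For subclasses that do need behavior around query execution, the safe pattern is the one this new warning describes: wrap a super().execute_queries() call rather than re-implementing it. A minimal sketch of that pattern (the builder name and post-processing hook here are hypothetical; the PsmBuilder added later in this commit follows the same shape):

from cumulus_library.base_table_builder import BaseTableBuilder


class MyStatsBuilder(BaseTableBuilder):
    # Hypothetical subclass illustrating the recommended override pattern.

    def prepare_queries(self, cursor: object, schema: str):
        # Populate self.queries as usual; the base class runs them.
        self.queries.append("CREATE TABLE my_study__stats AS SELECT 1 AS x")

    def execute_queries(self, cursor, schema, verbose, drop_table=False):
        # Let the base implementation run the prepared queries as-is...
        super().execute_queries(cursor, schema, verbose, drop_table)
        # ...then layer extra work before or after that call.
        self._post_process(cursor, schema)

    def _post_process(self, cursor, schema):
        pass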
10 changes: 9 additions & 1 deletion cumulus_library/databases.py
@@ -1,4 +1,12 @@
-"""Abstraction layers for supported database backends (e.g. AWS & DuckDB)"""
+"""Abstraction layers for supported database backends (e.g. AWS & DuckDB)
+
+By convention, to keep this a relatively light wrapper layer: if you have to
+choose between a convenience function in a specific library (for example, the
+[pyathena to_sql function](https://github.com/laughingman7743/PyAthena/#to-sql))
+and using raw SQL directly in some form, you should do the latter. This is not
+a law; if there's a compelling reason to use a library helper, just make sure
+you add an appropriate wrapper method in DatabaseCursor or DatabaseBackend.
+"""
 
 import abc
 import datetime
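As a hypothetical illustration of that docstring's convention: instead of reaching for a library helper like pyathena's to_sql, study code keeps the SQL explicit, and any genuinely shared behavior lands on the backend as a small wrapper. The method name and signature below are assumptions for illustration, not part of this commit:

import abc


class DatabaseBackend(abc.ABC):
    # Trimmed stand-in for the real class; only the wrapper below is shown.

    @abc.abstractmethod
    def cursor(self):
        ...

    # Hypothetical wrapper: raw SQL stays in one place, so every backend
    # (Athena, DuckDB, ...) can support it without library-specific helpers.
    def create_table_as(self, table_name: str, select_sql: str) -> None:
        self.cursor().execute(f"CREATE TABLE {table_name} AS ({select_sql})")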
230 changes: 230 additions & 0 deletions cumulus_library/statistics/psm.py
@@ -0,0 +1,230 @@
# Builder for propensity score matching (PSM) cohort tables.

import json
import os
from dataclasses import dataclass
from pathlib import PosixPath

import numpy as np
import pandas
import toml

# the star import is what provides PsmPy (and its plotting helpers) below
from psmpy.plotting import *

from cumulus_library.base_table_builder import BaseTableBuilder
from cumulus_library.cli import StudyBuilder
from cumulus_library.databases import AthenaDatabaseBackend as AthenaDb
from cumulus_library.template_sql.statistics.psm_templates import (
    get_create_covariate_table,
    get_distinct_ids,
)
from cumulus_library.template_sql.templates import (
    get_ctas_query_from_df,
    get_drop_view_table,
)


@dataclass
class PsmConfig:
"""Provides expected values for PSM execution"""

pos_source_table: str
neg_source_table: str
target_table: str
primary_ref: str
count_ref: str
count_table: str
dependent_variable: str
pos_sample_size: int
neg_sample_size: int
join_cols_by_table: dict[str, dict]
seed: int = 1234567890


class PsmBuilder(BaseTableBuilder):
display_text = "Building PSM tables..."

def __init__(self, toml_config_path: str):
with open(toml_config_path, encoding="UTF-8") as file:
toml_config = toml.load(file)
self.config = PsmConfig(
pos_source_table=toml_config["pos_source_table"],
neg_source_table=toml_config["neg_source_table"],
target_table=toml_config["target_table"],
primary_ref=toml_config["primary_ref"],
dependent_variable=toml_config["dependent_variable"],
pos_sample_size=toml_config["pos_sample_size"],
neg_sample_size=toml_config["neg_sample_size"],
join_cols_by_table=toml_config["join_cols_by_table"],
count_ref=toml_config.get("count_ref", None),
count_table=toml_config.get("count_table", None),
            # seed falls back to the PsmConfig default if not set in the toml
            seed=toml_config.get("seed", PsmConfig.seed),
)
super().__init__()

def _get_symptoms_dict(self, path: str) -> dict:
with open(path) as f:
symptoms = json.load(f)
return symptoms

    # TODO: replace cursor object with new dbcursor object

def _get_sampled_ids(
self,
cursor: object,
schema: str,
query: str,
sample_size: int,
label: str,
is_positive: bool,
):
"""Creates a table containing randomly sampled patients for PSM analysis
To use this, it is assumed you have already identified a cohort of positively
IDed patients as a manual process.
TODO: recommend a name
"""
df = cursor.execute(query).as_pandas()
df = (
df.sort_values(by=list(df.columns))
.reset_index()
.sample(n=sample_size, random_state=self.config.seed)
)
df[label] = is_positive
return df

def _create_covariate_table(self, cursor: object, schema: str):
        # Dedupe the refs (the primary ref and count ref may be the same column),
        # and discard count_ref if it was not configured
        source_refs = list({self.config.primary_ref, self.config.count_ref} - {None})

# Sample a set of ids inside/outside of the cohort
pos_query = get_distinct_ids(source_refs, self.config.pos_source_table)
pos = self._get_sampled_ids(
cursor,
schema,
pos_query,
self.config.pos_sample_size,
self.config.dependent_variable,
1,
)
neg_query = get_distinct_ids(
source_refs,
self.config.neg_source_table,
join_id=self.config.primary_ref,
filter_table=self.config.pos_source_table,
)
neg = self._get_sampled_ids(
cursor,
schema,
neg_query,
self.config.neg_sample_size,
self.config.dependent_variable,
0,
)
cohort = pandas.concat([pos, neg])

# Replace table (if it exists)
# drop = get_drop_view_table(f"{self.config.pos_source_table}_sampled_ids", 'TABLE')
# cursor.execute(drop)
ctas_query = get_ctas_query_from_df(
schema, f"{self.config.pos_source_table}_sampled_ids", cohort
)
self.queries.append(ctas_query)
# drop = get_drop_view_table(self.config.target_table, 'TABLE')
# cursor.execute(drop)
dataset_query = get_create_covariate_table(
target_table=self.config.target_table,
pos_source_table=self.config.pos_source_table,
neg_source_table=self.config.neg_source_table,
primary_ref=self.config.primary_ref,
dependent_variable=self.config.dependent_variable,
join_cols_by_table=self.config.join_cols_by_table,
count_ref=self.config.count_ref,
count_table=self.config.count_table,
)
self.queries.append(dataset_query)

def generate_psm_analysis(self, cursor: object, schema: str):
df = cursor.execute(f"select * from {self.config.target_table}").as_pandas()
symptoms_dict = self._get_symptoms_dict(
"../../tests/test_data/psm/symptoms.json"
)
for dependent_variable, codes in symptoms_dict.items():
df[dependent_variable] = df["code"].apply(lambda x: 1 if x in codes else 0)
df = df.drop(columns="code")
df = df.drop(columns="instance_count")
for column in ["gender", "race"]:
encoded_df = pandas.get_dummies(df[column])
df = pandas.concat([df, encoded_df], axis=1)
df = df.drop(column, axis=1)
df = df.reset_index()
try:
psm = PsmPy(
df,
treatment=self.config.dependent_variable,
indx=self.config.primary_ref,
exclude=[],
)
psm.logistic_ps(balance=True)
print(psm.predicted_data)
psm.knn_matched(
matcher="propensity_logit",
replacement=False,
caliper=None,
drop_unmatched=True,
)
print(psm.df_matched)
except ZeroDivisionError:
print(
"Encountered a divide by zero error during statistical graph generation. Try increasing your sample size."
)
except ValueError:
print(
"Encountered a value error during KNN matching. Try increasing your sample size."
)

def prepare_queries(self, cursor: object, schema: str):
self._create_covariate_table(cursor, schema)

def execute_queries(
self,
cursor: object,
schema: str,
verbose: bool,
drop_table: bool = False,
):
super().execute_queries(cursor, schema, verbose, drop_table)
self.generate_psm_analysis(cursor, schema)


if __name__ == "__main__":
arg_env_pairs = (
("profile", "CUMULUS_LIBRARY_PROFILE"),
("schema_name", "CUMULUS_LIBRARY_DATABASE"),
("workgroup", "CUMULUS_LIBRARY_WORKGROUP"),
("region", "CUMULUS_LIBRARY_REGION"),
("study_dir", "CUMULUS_LIBRARY_STUDY_DIR"),
("data_path", "CUMULUS_LIBRARY_DATA_PATH"),
("user", "CUMULUS_AGGREGATOR_USER"),
("id", "CUMULUS_AGGREGATOR_ID"),
("url", "CUMULUS_AGGREGATOR_URL"),
)
args = {}
read_env_vars = []
for pair in arg_env_pairs:
if env_val := os.environ.get(pair[1]):
if pair[0] == "study_dir":
args[pair[0]] = [env_val]
else:
args[pair[0]] = env_val
read_env_vars.append([pair[1], env_val])
database = AthenaDb(
args["region"], args["workgroup"], args["profile"], args["schema_name"]
)
builder = StudyBuilder(database)
    psm = PsmBuilder("../../tests/test_data/psm/psm_config.toml")
psm.execute_queries(
database.pandas_cursor, builder.schema_name, False, drop_table=True
)
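The psm_config.toml file referenced above is not part of this diff, but the keys PsmBuilder.__init__ reads imply a shape like the following. This is a sketch with invented table and column names; count_ref, count_table, and seed are optional since they are fetched with .get():

import toml

# Hypothetical config text mirroring the keys read in PsmBuilder.__init__
config_text = """
pos_source_table = "my_study__pos_cohort"
neg_source_table = "my_study__all_encounters"
target_table = "my_study__psm_covariates"
primary_ref = "encounter_ref"
count_ref = "subject_ref"
count_table = "my_study__condition"
dependent_variable = "has_symptom"
pos_sample_size = 100
neg_sample_size = 100
seed = 1234567890

[join_cols_by_table.my_study__demographics]
join_id = "subject_ref"
# one-element lists pass through; two-element lists become "col AS alias"
included_cols = [["gender"], ["race_display", "race"]]
"""

config = toml.loads(config_text)
assert config["join_cols_by_table"]["my_study__demographics"]["join_id"] == "subject_ref"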
40 changes: 40 additions & 0 deletions cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja
@@ -0,0 +1,40 @@
CREATE TABLE {{ target_table }} AS (
SELECT
DISTINCT sample_cohort."{{ primary_ref }}",
sample_cohort."{{ dependent_variable }}",
{%- if count_table %}
(
SELECT COUNT( DISTINCT {{ primary_ref }} )
FROM "{{ count_table }}"
WHERE sample_cohort."{{ count_ref }}" = "{{ count_table }}"."{{ count_ref }}"
--AND sample_cohort.enc_end_date >= "{{ count_table }}".recordeddate
) AS instance_count,
{%- endif %}
{%- for key in join_cols_by_table %}
{%- for column in join_cols_by_table[key]["included_cols"] %}
{%- if column|length == 1 %}
"{{ key }}"."{{ column[0] }}",
{%- else %}
"{{ key }}"."{{ column[0] }}" AS "{{ column[1]}}",
{%- endif %}
{%- endfor %}
{%- endfor %}
{{ neg_source_table }}.code
FROM "{{ pos_source_table }}_sampled_ids" AS sample_cohort,
"{{ neg_source_table }}",
{%- for key in join_cols_by_table %}
"{{ key }}"
{%- if not loop.last -%}
,
{%- endif -%}
{% endfor %}
WHERE sample_cohort."{{ primary_ref }}" = "{{ neg_source_table }}"."{{ primary_ref }}"
{%- for key in join_cols_by_table %}
    AND sample_cohort."{{ join_cols_by_table[key]["join_id"] }}" = "{{ key }}"."{{ join_cols_by_table[key]["join_id"] }}"
{% endfor %}
-- AND c.recordeddate <= sample_cohort.enc_end_date
ORDER BY sample_cohort."{{ primary_ref }}"
)
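A sketch of how this template gets exercised, mirroring the get_create_covariate_table call in psm.py. Rendering here goes through jinja2 directly for illustration; the real loader in psm_templates.py is not shown in this diff, and the file path and argument values are assumptions:

from jinja2 import Template

with open(
    "cumulus_library/template_sql/statistics/psm_create_covariate_table.sql.jinja"
) as f:
    template = Template(f.read())

sql = template.render(
    target_table="my_study__psm_covariates",
    pos_source_table="my_study__pos_cohort",
    neg_source_table="my_study__all_encounters",
    primary_ref="encounter_ref",
    dependent_variable="has_symptom",
    count_ref="subject_ref",
    count_table="my_study__condition",
    join_cols_by_table={
        "my_study__demographics": {
            "join_id": "subject_ref",
            "included_cols": [["gender"], ["race_display", "race"]],
        }
    },
)
print(sql)  # CREATE TABLE my_study__psm_covariates AS ( SELECT DISTINCT ... )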
14 changes: 14 additions & 0 deletions cumulus_library/template_sql/statistics/psm_distinct_ids.sql.jinja
@@ -0,0 +1,14 @@
SELECT DISTINCT
{%- for column in columns %}
"{{ source_table }}"."{{ column }}"
{%- if not loop.last -%}
,
{%- endif -%}
{%- endfor %}
FROM {{ source_table }}
{%- if join_id %}
WHERE "{{ source_table }}"."{{ join_id }}" NOT IN (
SELECT "{{ filter_table }}"."{{ join_id }}"
FROM {{ filter_table }}
)
{%- endif -%}
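The two get_distinct_ids calls in _create_covariate_table above exercise both branches of this template. A sketch of the negative-cohort call (column and table names invented) and roughly what it renders to:

from cumulus_library.template_sql.statistics.psm_templates import get_distinct_ids

query = get_distinct_ids(
    ["encounter_ref", "subject_ref"],  # columns
    "my_study__all_encounters",        # source_table
    join_id="encounter_ref",           # enables the NOT IN exclusion branch
    filter_table="my_study__pos_cohort",
)
# Roughly:
#   SELECT DISTINCT
#       "my_study__all_encounters"."encounter_ref",
#       "my_study__all_encounters"."subject_ref"
#   FROM my_study__all_encounters
#   WHERE "my_study__all_encounters"."encounter_ref" NOT IN (
#       SELECT "my_study__pos_cohort"."encounter_ref"
#       FROM my_study__pos_cohort
#   )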