PSM table generation #150

Merged
merged 17 commits into from
Dec 5, 2023
Changes from 12 commits
12 changes: 11 additions & 1 deletion cumulus_library/.sqlfluff
@@ -4,7 +4,7 @@ dialect = athena
sql_file_exts = .sql,.sql.jinja
# this rule overfires on athena nested arrays
exclude_rules=references.from,structure.column_order,aliasing.unused
max_line_length = 88
max_line_length = 90

[sqlfluff:indentation]
template_blocks_indent = false
@@ -18,20 +18,30 @@ capitalisation_policy = upper
[sqlfluff:templater:jinja:context]
code_systems = ["http://snomed.info/sct", "http://hl7.org/fhir/sid/icd-10-cm"]
col_type_list = ["a string","b string"]
columns = ['a','b']
cc_columns = [{"name": "baz", "is_array": True}, {"name": "foobar", "is_array": False}]
cc_column = 'code'
code_system_tables = [{"table_name":"hasarray","column_name":"acol","is_bare_coding":False,"is_array":True, "has_data": True},{"table_name":"noarray","column_name":"col","is_bare_coding":False,"is_array":False, "has_data": True},{"table_name":"bare","column_name":"bcol","is_bare_coding":True,"is_array":False, "has_data": True},{"table_name":"empty","column_name":"empty","is_bare_coding":False,"is_array":False, "has_data": False}]
column_name = 'bar'
conditions = ["1 > 0", "1 < 2"]
count_ref = count_ref
count_table = count_table
dataset = [["foo","foo"],["bar","bar"]]
dependent_variable = is_flu
ext_systems = ["omb", "text"]
field = 'column_name'
filter_table = filter_table
fhir_extension = fhir_extension
fhir_resource = patient
id = 'id'
join_cols_by_table = { "join_table": { "join_id": "enc_ref","included_cols": [["a"], ["b", "c"]]}}
join_id = subject_ref
medication_datasources = {"by_contained_ref" : True, "by_external_ref" : True}
neg_source_table = neg_source_table
output_table_name = 'created_table'
prefix = Test
primary_ref = encounter_ref
pos_source_table = pos_source_table
schema_name = test_schema
source_table = source_table
source_id = source_id
8 changes: 6 additions & 2 deletions cumulus_library/base_table_builder.py
@@ -2,7 +2,6 @@
import re

from abc import ABC, abstractmethod
from typing import final

from cumulus_library.databases import DatabaseCursor
from cumulus_library.helper import get_progress_bar, query_console_output
@@ -33,7 +32,12 @@ def prepare_queries(self, cursor: object, schema: str):
"""
raise NotImplementedError

@final
# 🚨🚨🚨 WARNING: in 99% of cases, subclasses should *not* re-implement 🚨🚨🚨
# 🚨🚨🚨 execute_queries. 🚨🚨🚨

# If you know what you are doing, you can attempt to override it, but it is
# strongly recommended you invoke this as is via a super() call, and then
# run code before or after that as makes sense for your use case.
def execute_queries(
self,
cursor: DatabaseCursor,
27 changes: 26 additions & 1 deletion cumulus_library/databases.py
@@ -1,4 +1,12 @@
"""Abstraction layers for supported database backends (e.g. AWS & DuckDB)"""
"""Abstraction layers for supported database backends (e.g. AWS & DuckDB)

By convention, to maintain this as a relatively light wrapper layer, if you have
to choose between a convenience function in a specific library (as an example, the
[pyathena to_sql function](https://github.com/laughingman7743/PyAthena/#to-sql))
or using raw SQL directly in some form, you should do the latter. This is not a law;
if there's a compelling reason to do so, just make sure you add an appropriate
wrapper method in one of DatabaseCursor or DatabaseBackend.
"""

import abc
import datetime
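
As a rough illustration of the convention described in the docstring above (not code from this PR), loading a small dataframe could be done with a plain CREATE TABLE AS statement issued through the cursor rather than a library-specific helper such as pyathena's to_sql; the helper name, quoting, and typing below are simplified placeholders:

```python
import pandas


def load_df_via_raw_sql(cursor, schema: str, table: str, df: pandas.DataFrame) -> None:
    """Hypothetical sketch: build a CTAS statement from a dataframe and execute it."""
    rows = ", ".join(
        "(" + ", ".join(f"'{value}'" for value in row) + ")"
        for row in df.itertuples(index=False)
    )
    columns = ", ".join(df.columns)
    cursor.execute(
        f"CREATE TABLE {schema}.{table} AS "
        f"SELECT * FROM (VALUES {rows}) AS t ({columns})"
    )
```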
@@ -47,6 +55,14 @@ def __init__(self, schema_name: str):
def cursor(self) -> DatabaseCursor:
"""Returns a connection to the backing database"""

@abc.abstractmethod
def pandas_cursor(self) -> DatabaseCursor:
"""Returns a connection to the backing database optimized for dataframes

If your database does not provide an optimized cursor, this should function the
same as a vanilla cursor.
"""

Comment on lines +58 to +65
Contributor Author

So this is the change to the DB class I was mentioning, and I'm hoping that this comment explains why it's in here the way it is, but to be a bit more verbose about this: pyathena has a method that dramatically improves query execution when it's looking to return a dataframe - something about how they handle chunking under the hood. So, in context, when I'm passing a cursor to a method, I sometimes elect to specifically hand one of these pandas cursors off.

I did this while testing the PSM code (where the cursor is the entrypoint - we :could: rewrite table builders to take a Connection rather than a Cursor, but that's a big refactor by itself and this is already pretty gross), and in the future manifest parsing hook for this (to come as a follow-on PR), I'm planning on specifying the pandas cursor for PSM invocation. The DuckDB version just returns a regular cursor.

Contributor

Yeah I'm fine with this change based on the constraint of "Cursor is the interface, not DatabaseBackend/Connection". Some thoughts around it though:

  • I'd like to see as_pandas added to the Cursor protocol we have, so that consumers of Library know it's contractually available. (See below for some commentary on this.)
  • I'd like to see execute_as_pandas dropped -- I only added that to avoid the need for extending cursors like this. But now we could simplify that interface.
  • The solution of creating an alias for as_pandas in the duckdb returned cursor is fine, but gives me pause because clever monkey-patching can be taken too far. 😄 If this setup gets more complicated, I might vote for a DuckCursor wrapper object that does similar kind of translations needed in future.
  • We really now have two kinds of Cursors - those for which as_pandas is available and those for which it isn't. What happens on a PyAthena normal cursor if you call as_pandas?
    • For our purposes, maybe AthenaDatabaseBackend should create a wrapper AthenaCursor object that throws an exception if you try to call as_pandas on the wrong cursor object.
    • Or even better probably, have two different Cursor protocols. One pandas-powered and one that isn't. That way method signatures would be clear about which cursor they expect to be handed. (if that is always clear?)
    • You could also add Cursor wrappers and a method like .get_database_backend() or something to give access to parent objects without introducing two different kinds of Cursors. But that's a little clunky in its own way. But may feel less clunky.

Contributor Author

honestly - i think i like the idea of refactoring one way or another to get these more in line, i'm just trying to not do it as part of this PR for complexity reasons - we can maybe natter about the shape? some options, pulling on some of these threads:

  • I don't hate making a database connection the atomic unit, but it is probably going to touch the most things
  • as_pandas is, apparently, available as a util method that can be called on a pyathena cursor, so we could switch to that and keep the cursor space down to one per db. that might slot better into the execute_as_pandas paradigm
  • I think generally a PEP cursor has a reference back to its connection, so maybe it's not the end of the world to have it get the database backend, though i think that's my least favorite of these.
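
To make the two pyathena routes mentioned above concrete, here is a minimal sketch (not part of this diff; the connection arguments and table name are placeholders, and exact signatures should be double-checked against the pyathena docs):

```python
from pyathena import connect
from pyathena.pandas.cursor import PandasCursor  # chunk-optimized dataframe cursor
from pyathena.pandas.util import as_pandas       # converts a plain cursor's results

conn = connect(
    work_group="example_workgroup",  # placeholder
    region_name="us-east-1",         # placeholder
    schema_name="example_schema",    # placeholder
)

# Route 1: a dedicated pandas cursor, which is what pandas_cursor() exposes here
pandas_cursor = conn.cursor(PandasCursor)
df = pandas_cursor.execute("SELECT * FROM example_table").as_pandas()

# Route 2: keep a single vanilla cursor per database and convert after execution
cursor = conn.cursor()
df = as_pandas(cursor.execute("SELECT * FROM example_table"))
```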

@abc.abstractmethod
def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
"""Returns a pandas.DataFrame version of the results from the provided SQL"""
@@ -85,6 +101,9 @@ def __init__(self, region: str, workgroup: str, profile: str, schema_name: str):
def cursor(self) -> AthenaCursor:
return self.connection.cursor()

def pandas_cursor(self) -> AthenaPandasCursor:
return self.pandas_cursor

def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
return self.pandas_cursor.execute(sql).as_pandas()

@@ -95,6 +114,8 @@ class DuckDatabaseBackend(DatabaseBackend):
def __init__(self, db_file: str):
super().__init__("main")
self.connection = duckdb.connect(db_file)
# Aliasing Athena's as_pandas to duckDB's df cast
setattr(duckdb.DuckDBPyConnection, "as_pandas", duckdb.DuckDBPyConnection.df)

# Paper over some syntax differences between Athena and DuckDB
self.connection.create_function(
@@ -150,6 +171,10 @@ def cursor(self) -> duckdb.DuckDBPyConnection:
# because then we'd have to re-register our json tables.
return self.connection

def pandas_cursor(self) -> duckdb.DuckDBPyConnection:
# Since this is not provided, return the vanilla cursor
return self.connection

def execute_as_pandas(self, sql: str) -> pandas.DataFrame:
# We call convert_dtypes here in case there are integer columns.
# Pandas will normally cast nullable-int as a float type unless
247 changes: 247 additions & 0 deletions cumulus_library/statistics/psm.py
@@ -0,0 +1,247 @@
# Module for generating Propensity Score matching cohorts

import numpy as np
import pandas
import toml

from psmpy import PsmPy


import json
from pathlib import PosixPath
from dataclasses import dataclass

from cumulus_library.cli import StudyBuilder
from cumulus_library.databases import DatabaseCursor
from cumulus_library.base_table_builder import BaseTableBuilder
from cumulus_library.template_sql.templates import (
get_ctas_query_from_df,
get_drop_view_table,
)
from cumulus_library.template_sql.statistics.psm_templates import (
get_distinct_ids,
get_create_covariate_table,
)


@dataclass
class PsmConfig:
"""Provides expected values for PSM execution

These values should be read in from a toml configuration file.
See tests/test_data/psm/psm_config.toml for an example with details about
the expected values for these fields.
"""

classification_json: str
pos_source_table: str
neg_source_table: str
target_table: str
primary_ref: str
count_ref: str
count_table: str
dependent_variable: str
pos_sample_size: int
neg_sample_size: int
join_cols_by_table: dict[str, dict]
seed: int
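
For a sense of what these fields look like on disk, a hypothetical config along the following lines would parse into the dataclass above; every value here is invented, and tests/test_data/psm/psm_config.toml remains the canonical example:

```python
import toml

# All values below are invented for illustration only
EXAMPLE_CONFIG = toml.loads("""
classification_json = "flu_classifications.json"
pos_source_table = "study__pos_cohort"
neg_source_table = "study__neg_cohort"
target_table = "study__psm_covariates"
primary_ref = "encounter_ref"
count_ref = "subject_ref"
count_table = "study__symptoms"
dependent_variable = "is_flu"
pos_sample_size = 50
neg_sample_size = 50
seed = 1234

[join_cols_by_table.study__symptoms]
join_id = "encounter_ref"
included_cols = [["code"], ["code_system", "system"]]
""")
```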


class PsmBuilder(BaseTableBuilder):
"""TableBuilder for creating PSM tables"""

display_text = "Building PSM tables..."

def __init__(self, toml_config_path: str):
"""Loads PSM job details from a PSM configuration file"""
with open(toml_config_path, encoding="UTF-8") as file:
toml_config = toml.load(file)
self.config = PsmConfig(
classification_json=f"{PosixPath(toml_config_path).parent}/{toml_config['classification_json']}",
pos_source_table=toml_config["pos_source_table"],
neg_source_table=toml_config["neg_source_table"],
target_table=toml_config["target_table"],
primary_ref=toml_config["primary_ref"],
dependent_variable=toml_config["dependent_variable"],
pos_sample_size=toml_config["pos_sample_size"],
neg_sample_size=toml_config["neg_sample_size"],
join_cols_by_table=toml_config.get("join_cols_by_table", {}),
count_ref=toml_config.get("count_ref", None),
count_table=toml_config.get("count_table", None),
seed=toml_config.get("seed", 123),
)
super().__init__()

def _get_symptoms_dict(self, path: str) -> dict:
"""convenience function for loading symptoms dictionaries from a json file"""
with open(path) as f:
symptoms = json.load(f)
return symptoms
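
Based on how generate_psm_analysis consumes this dictionary further down (each key becomes a 0/1 flag column and each value is a list of codes matched against the code column), the JSON payload is expected to look roughly like this invented example:

```python
# Invented example of classification_json contents
EXAMPLE_SYMPTOMS = {
    "fever": ["386661006", "R50.9"],  # codes whose presence sets the fever flag
    "cough": ["49727002", "R05"],
}
```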

def _get_sampled_ids(
self,
cursor: DatabaseCursor,
schema: str,
query: str,
sample_size: int,
dependent_variable: str,
is_positive: bool,
):
"""Creates a table containing randomly sampled patients for PSM analysis

To use this, it is assumed you have already identified a cohort of positively
IDed patients as a manual process.
:param cursor: a valid DatabaseCursor
:param schema: the schema/database name where the data exists
:param query: a query generated from the psm_distinct_ids template
:param sample_size: the number of records to include in the random sample.
This should generally be >= 20.
:param dependent_variable: the name to use for your filtering column
:param is_positive: defines the value to be used for your filtering column
"""
df = cursor.execute(query).as_pandas()
df = (
df.sort_values(by=[self.config.primary_ref])
.reset_index()
.drop("index", axis=1)
)
df = (
# TODO: remove replace behavior after increasing data sample size
Contributor
@mikix Dec 4, 2023

Is this still valid? As a comment to a non-Matt, I'm not sure it's clear to me when that would be true. I would currently guess this is about unit tests. (Same comment above the call to knn_matched later)

Contributor Author

ok no this is about the PsmPy behavior vs valid statistical sampling techniques, which i'm intending to deal with in the follow-on PR, though you can talk me into it now; this was more about niceness for a reviewer.

replace=True allows the same record to be sampled multiple times if it's below the floor value of 20. The floor value is set in PsmPy in a way which cannot be overridden; this is probably desirable from a math perspective but becomes harder when dealing with very small input sets. So this comes back to the side convo we had about data size - i could try to solve that here by touching ids and re-inserting our existing test data 3 or 4 times. But this value should be toggled before we ask anyone to use this.

Contributor Author

tried to make this comment more sensible for lottery/bus reasons

df.sample(n=sample_size, random_state=self.config.seed, replace=True)
.sort_values(by=[self.config.primary_ref])
.reset_index()
.drop("index", axis=1)
)

df[dependent_variable] = is_positive
return df
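
A small standalone illustration of the replace=True behavior discussed in the thread above (plain pandas, nothing PsmPy-specific): with replacement, a frame smaller than the requested sample still yields the full number of rows, at the cost of duplicates; without it, pandas raises an error:

```python
import pandas

df = pandas.DataFrame({"subject_ref": [f"patient/{i}" for i in range(5)]})

# Sampling 20 rows from a 5-row frame only works with replacement; rows repeat
oversampled = df.sample(n=20, random_state=123, replace=True)
print(oversampled["subject_ref"].value_counts())

# The same call with replace=False raises a ValueError about the sample being
# larger than the population
```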

def _create_covariate_table(self, cursor: DatabaseCursor, schema: str):
"""Creates a covariate table from the loaded toml config"""
# checks for primary & link ref being the same
source_refs = list(
{self.config.primary_ref, self.config.count_ref} - set([None])
)
pos_query = get_distinct_ids(source_refs, self.config.pos_source_table)
pos = self._get_sampled_ids(
cursor,
schema,
pos_query,
self.config.pos_sample_size,
self.config.dependent_variable,
1,
)
neg_query = get_distinct_ids(
source_refs,
self.config.neg_source_table,
join_id=self.config.primary_ref,
filter_table=self.config.pos_source_table,
)
neg = self._get_sampled_ids(
cursor,
schema,
neg_query,
self.config.neg_sample_size,
self.config.dependent_variable,
0,
)
cohort = pandas.concat([pos, neg])

# Replace table (if it exists)
# TODO - replace with timestamp prepended table
drop = get_drop_view_table(
f"{self.config.pos_source_table}_sampled_ids", "TABLE"
)
cursor.execute(drop)
ctas_query = get_ctas_query_from_df(
schema,
f"{self.config.pos_source_table}_sampled_ids",
cohort,
)
self.queries.append(ctas_query)
# TODO - replace with timestamp prepended table
drop = get_drop_view_table(self.config.target_table, "TABLE")
cursor.execute(drop)
dataset_query = get_create_covariate_table(
target_table=self.config.target_table,
pos_source_table=self.config.pos_source_table,
neg_source_table=self.config.neg_source_table,
primary_ref=self.config.primary_ref,
dependent_variable=self.config.dependent_variable,
join_cols_by_table=self.config.join_cols_by_table,
count_ref=self.config.count_ref,
count_table=self.config.count_table,
)
self.queries.append(dataset_query)

def generate_psm_analysis(self, cursor: object, schema: str):
"""Runs PSM statistics on generated tables"""
df = cursor.execute(f"select * from {self.config.target_table}").as_pandas()
symptoms_dict = self._get_symptoms_dict(self.config.classification_json)
for dependent_variable, codes in symptoms_dict.items():
df[dependent_variable] = df["code"].apply(lambda x: 1 if x in codes else 0)
df = df.drop(columns="code")
# instance_count present but unused for PSM if table contains a count_ref input
# (it's intended for manual review)
df = df.drop(columns="instance_count", errors="ignore")

columns = []
if self.config.join_cols_by_table is not None:
for table_key in self.config.join_cols_by_table:
for column in self.config.join_cols_by_table[table_key][
"included_cols"
]:
if len(column) == 2:
columns.append(column[1])
else:
columns.append(column[0])

for column in columns:
encoded_df = pandas.get_dummies(df[column])
df = pandas.concat([df, encoded_df], axis=1)
df = df.drop(column, axis=1)
df = df.reset_index()
Comment on lines +224 to +228
Contributor

A block like this could use a comment about why it's doing what it's doing. My rough attempt: you're replacing each column with a dummy version of that column. But even after reading the pandas docs on get_dummies, I'm not 💯 on what that means.

Contributor Author

yeah i'll add something - this is converting to a 1-hot encoding for all values of that column, basically pivoting the column values to new column headers.

Contributor Author

updated with some explanatory text
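
For anyone else unsure about get_dummies, a tiny standalone sketch of the pivot described above; the column name and values are invented:

```python
import pandas

df = pandas.DataFrame({"race_display": ["white", "asian", "white", "black"]})

encoded = pandas.get_dummies(df["race_display"])
# encoded has one indicator column per observed value (asian/black/white);
# depending on the pandas version the dtype is 0/1 integers or booleans
df = pandas.concat([df, encoded], axis=1).drop("race_display", axis=1)
```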


try:
psm = PsmPy(
df,
treatment=self.config.dependent_variable,
indx=self.config.primary_ref,
exclude=[],
)
# This function populates the psm.predicted_data element, which is required
# for things like the knn_matched() function call
psm.logistic_ps(balance=True)
print(psm.predicted_data)
# This function populates the psm.df_matched element
# TODO: flip replacement to false after increasing sample data size
psm.knn_matched(
matcher="propensity_logit",
replacement=True,
caliper=None,
drop_unmatched=True,
)
print(psm.df_matched)
except ZeroDivisionError:
print(
"Encountered a divide by zero error during statistical graph generation. Try increasing your sample size."
)
except ValueError:
print(
"Encountered a value error during KNN matching. Try increasing your sample size."
)

def prepare_queries(self, cursor: object, schema: str):
self._create_covariate_table(cursor, schema)

def execute_queries(
self,
cursor: object,
schema: str,
verbose: bool,
drop_table: bool = False,
):
super().execute_queries(cursor, schema, verbose, drop_table)
self.comment_queries()
self.write_queries()
self.generate_psm_analysis(cursor, schema)