Ruff update #272

Merged: 2 commits, Jul 29, 2024
Changes from 1 commit
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,7 +1,7 @@
 default_install_hook_types: [pre-commit, pre-push]
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.4  # if you update this, also update pyproject.toml
+    rev: v0.5.5  # if you update this, also update pyproject.toml
     hooks:
       - name: Ruff formatting
         id: ruff-format
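The in-file comment asks for the Ruff pin in pyproject.toml to be bumped in lockstep with this hook. A small, hypothetical helper for checking that the two files agree; the pyproject pin format and file layout are assumptions about this repo, not something shown in the diff:

import pathlib
import re

import yaml  # pyyaml, assumed available in the dev environment

config = yaml.safe_load(pathlib.Path(".pre-commit-config.yaml").read_text())
# Find the rev pinned for the ruff-pre-commit hook repo.
rev = next(r["rev"] for r in config["repos"] if r["repo"].endswith("ruff-pre-commit"))

# Look for a "ruff ... X.Y.Z" pin anywhere in pyproject.toml.
pin = re.search(r"ruff[^\n]*?(\d+\.\d+\.\d+)", pathlib.Path("pyproject.toml").read_text())

if pin is None or rev.lstrip("v") != pin.group(1):
    raise SystemExit(f"pre-commit pins {rev}; pyproject.toml pins {pin and pin.group(1)}")
print(f"Ruff pins in sync at {rev}")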
14 changes: 4 additions & 10 deletions cumulus_library/actions/builder.py
@@ -94,9 +94,7 @@ def _load_and_execute_builder(
     table_builder_class = table_builder_subclasses[0]
     table_builder = table_builder_class()
     if write_reference_sql:
-        table_builder.prepare_queries(
-            config=config, manifest=manifest, parser=db_parser
-        )
+        table_builder.prepare_queries(config=config, manifest=manifest, parser=db_parser)
         table_builder.comment_queries(doc_str=doc_str)
         new_filename = pathlib.Path(f"{filename}").stem + ".sql"
         table_builder.write_queries(
@@ -191,7 +189,7 @@ def run_statistics_builders(
     existing_stats = (
         config.db.cursor()
         .execute(
-            "SELECT view_name FROM "
+            "SELECT view_name FROM "  # noqa: S608
             f"{manifest.get_study_prefix()}__{enums.ProtectedTables.STATISTICS.value}"
         )
         .fetchall()
@@ -273,9 +271,7 @@ def build_study(
"""
queries = []
for file in manifest.get_sql_file_list(continue_from):
for query in base_utils.parse_sql(
base_utils.load_text(f"{manifest._study_path}/{file}")
):
for query in base_utils.parse_sql(base_utils.load_text(f"{manifest._study_path}/{file}")):
queries.append([query, file])
if len(queries) == 0:
return []
@@ -377,9 +373,7 @@ def _execute_build_queries(
"start with a string like `study_prefix__`.",
)
try:
with base_utils.query_console_output(
config.verbose, query[0], progress, task
):
with base_utils.query_console_output(config.verbose, query[0], progress, task):
cursor.execute(query[0])
except Exception as e: # pylint: disable=broad-exception-caught
_query_error(
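Most of the new # noqa: S608 pragmas in this diff suppress Ruff's bandit-derived rule against SQL assembled from strings. A minimal sketch of the pattern the rule flags and the parameterized form it prefers; the sqlite3 driver and table name are illustrative, not from this repo:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE study__stats (view_name TEXT, study_name TEXT)")
study = "core"

# S608 flags f-string SQL, since interpolated input can inject SQL:
rows = conn.execute(
    f"SELECT view_name FROM study__stats WHERE study_name = '{study}'"  # noqa: S608
).fetchall()

# Preferred where possible: bind values as parameters. Identifiers (table
# names, study prefixes) cannot be bound, which is why library code that
# only interpolates trusted, internally generated names silences the rule.
rows = conn.execute(
    "SELECT view_name FROM study__stats WHERE study_name = ?", (study,)
).fetchall()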
6 changes: 2 additions & 4 deletions cumulus_library/actions/cleaner.py
@@ -57,7 +57,7 @@ def get_unprotected_stats_view_table(
     protected_list = cursor.execute(
         f"""SELECT {artifact_type.lower()}_name
         FROM {drop_prefix}{enums.ProtectedTables.STATISTICS.value}
-        WHERE study_name = '{display_prefix}'"""
+        WHERE study_name = '{display_prefix}'"""  # noqa: S608
     ).fetchall()
     for protected_tuple in protected_list:
         if protected_tuple in db_contents:
@@ -66,9 +66,7 @@ def get_unprotected_stats_view_table(
         # this check handles athena reporting views as also being tables,
         # so we don't waste time dropping things that don't exist
         if artifact_type == "TABLE":
-            if not any(
-                db_row_tuple[0] in iter_q_and_t for iter_q_and_t in unprotected_list
-            ):
+            if not any(db_row_tuple[0] in iter_q_and_t for iter_q_and_t in unprotected_list):
                 unprotected_list.append([db_row_tuple[0], artifact_type])
         else:
             unprotected_list.append([db_row_tuple[0], artifact_type])
10 changes: 3 additions & 7 deletions cumulus_library/actions/exporter.py
@@ -27,9 +27,7 @@ def reset_counts_exports(
 def _write_chunk(writer, chunk, arrow_schema):
     writer.write(
         pyarrow.Table.from_pandas(
-            chunk.sort_values(
-                by=list(chunk.columns), ascending=False, na_position="first"
-            ),
+            chunk.sort_values(by=list(chunk.columns), ascending=False, na_position="first"),
             preserve_index=False,
             schema=arrow_schema,
         )
@@ -70,11 +68,9 @@ def export_study(
         table_list,
         description=f"Exporting {manifest.get_study_prefix()} data...",
     ):
-        query = f"SELECT * FROM {table}"
+        query = f"SELECT * FROM {table}"  # noqa: S608
         query = base_utils.update_query_if_schema_specified(query, manifest)
-        dataframe_chunks, db_schema = config.db.execute_as_pandas(
-            query, chunksize=chunksize
-        )
+        dataframe_chunks, db_schema = config.db.execute_as_pandas(query, chunksize=chunksize)
         path.mkdir(parents=True, exist_ok=True)
         arrow_schema = pyarrow.schema(config.db.col_pyarrow_types_from_sql(db_schema))
         with parquet.ParquetWriter(f"{path}/{table}.parquet", arrow_schema) as p_writer:
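For context on the export path reflowed above: query results are streamed to Parquet in chunks so memory stays bounded regardless of table size. A stripped-down sketch of the same pyarrow pattern, with a made-up schema and data:

import pandas as pd
import pyarrow
from pyarrow import parquet

# One schema shared by every chunk lets them all land in a single file.
arrow_schema = pyarrow.schema([("id", pyarrow.string()), ("count", pyarrow.int64())])

chunks = (pd.DataFrame({"id": [f"row{i}"], "count": [i]}) for i in range(3))

with parquet.ParquetWriter("example.parquet", arrow_schema) as writer:
    for chunk in chunks:
        # preserve_index=False drops the pandas index; schema= validates
        # each chunk against the shared schema before it is appended.
        writer.write(pyarrow.Table.from_pandas(chunk, preserve_index=False, schema=arrow_schema))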
13 changes: 4 additions & 9 deletions cumulus_library/apis/umls.py
@@ -99,9 +99,7 @@ def get_vsac_valuesets(
         included_records = all_responses[0].get("compose", {}).get("include", [])
         for record in included_records:
             if "valueSet" in record:
-                valueset = self.get_vsac_valuesets(
-                    action=action, url=record["valueSet"][0]
-                )
+                valueset = self.get_vsac_valuesets(action=action, url=record["valueSet"][0])
                 all_responses.append(valueset[0])
         return all_responses

@@ -143,24 +141,21 @@ def download_umls_files(
"url": file_meta["downloadUrl"],
"apiKey": self.api_key,
}
download_res = requests.get(
download_res = requests.get( # noqa: S113
"https://uts-ws.nlm.nih.gov/download", params=download_payload, stream=True
)

with open(path / file_meta["fileName"], "wb") as f:
chunks_read = 0
with base_utils.get_progress_bar() as progress:
task = progress.add_task(
f"Downloading {file_meta['fileName']}", total=None
)
task = progress.add_task(f"Downloading {file_meta['fileName']}", total=None)
for chunk in download_res.iter_content(chunk_size=1024):
f.write(chunk)
chunks_read += 1
progress.update(
task,
description=(
f"Downloading {file_meta['fileName']}: "
f"{chunks_read/1000} MB"
f"Downloading {file_meta['fileName']}: " f"{chunks_read/1000} MB"
),
)
if unzip:
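The new # noqa: S113 marks Ruff's rule against requests calls made without a timeout; without one, a stalled server can block the download forever. A sketch of the form the rule wants, with a placeholder URL:

import requests

# S113 would flag this: no timeout, so a silent peer blocks indefinitely.
# resp = requests.get("https://example.com/file", stream=True)

# Preferred: bound the connect and per-read time. The timeout applies to
# each socket read rather than the whole transfer, so long streaming
# downloads still complete.
resp = requests.get("https://example.com/file", stream=True, timeout=30)
for chunk in resp.iter_content(chunk_size=1024):
    pass  # write each chunk to disk here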
18 changes: 5 additions & 13 deletions cumulus_library/base_table_builder.py
@@ -56,17 +56,13 @@ def execute_queries(
         table_names = []
         for query in self.queries:
             # Get the first non-whitespace word after create table
-            table_name = re.search(
-                '(?i)(?<=create table )(([a-zA-Z0-9_".-]+))', query
-            )
+            table_name = re.search('(?i)(?<=create table )(([a-zA-Z0-9_".-]+))', query)

             if table_name:
                 if table_name[0] == "IF":
                     # Edge case - if we're doing an empty conditional CTAS creation,
                     # we need to run a slightly different regex
-                    table_name = re.search(
-                        '(?i)(?<=not exists )(([a-zA-Z0-9_".-]+))', query
-                    )
+                    table_name = re.search('(?i)(?<=not exists )(([a-zA-Z0-9_".-]+))', query)

                 table_name = table_name[0]
                 table_names.append(table_name)
@@ -81,20 +77,16 @@ def execute_queries(
             for query in self.queries:
                 try:
                     query = base_utils.update_query_if_schema_specified(query, manifest)
-                    with base_utils.query_console_output(
-                        config.verbose, query, progress, task
-                    ):
+                    with base_utils.query_console_output(config.verbose, query, progress, task):
                         cursor.execute(query)
                 except Exception as e:  # pylint: disable=broad-exception-caught
                     sys.exit(
-                        "An error occured executing this query:\n----\n"
-                        f"{query}\n----\n"
-                        f"{e}"
+                        "An error occured executing this query:\n----\n" f"{query}\n----\n" f"{e}"
                     )

         self.post_execution(config, *args, **kwargs)

-    def post_execution(  # noqa: B027 - this looks like, but is not, an abstract method
+    def post_execution(
         self,
         config: base_utils.StudyConfig,
         *args,
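The dropped # noqa: B027 suggests the updated Ruff no longer needs the suppression at this call site (or reported it as unused). For reference, B027 fires on an empty, non-abstract method in an abstract base class, on the theory that @abstractmethod was forgotten. A sketch with hypothetical names:

import abc


class Builder(abc.ABC):
    @abc.abstractmethod
    def prepare_queries(self) -> None: ...

    # B027 would flag this: an empty method on an ABC that is not marked
    # abstract. Here the empty body is deliberate; it is an optional hook
    # that subclasses may override, mirroring post_execution in this diff.
    def post_execution(self) -> None:
        pass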
8 changes: 3 additions & 5 deletions cumulus_library/log_utils.py
@@ -94,7 +94,7 @@ def _log_table(
     # Migrating logging tables
     if "lib_transactions" in table_name:
         cols = cursor.execute(
-            "SELECT column_name FROM information_schema.columns "
+            "SELECT column_name FROM information_schema.columns "  # noqa: S608
             f"WHERE table_name ='{table_name}' "
             f"AND table_schema ='{db_schema}'"
         ).fetchall()
@@ -109,13 +109,11 @@ def _log_table(
alter_query = ""
if isinstance(config.db, databases.AthenaDatabaseBackend):
alter_query = (
f"ALTER TABLE {db_schema}.{table_name} "
"ADD COLUMNS(message string)"
f"ALTER TABLE {db_schema}.{table_name} " "ADD COLUMNS(message string)"
)
elif isinstance(config.db, databases.DuckDatabaseBackend):
alter_query = (
f"ALTER TABLE {db_schema}.{table_name} "
"ADD COLUMN message varchar"
f"ALTER TABLE {db_schema}.{table_name} " "ADD COLUMN message varchar"
)
cursor.execute(alter_query)
cursor.execute(query)
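Several of the reflowed lines above lean on implicit concatenation of adjacent string literals, which Python performs at compile time; it splits long SQL across literals without "+" or str.join(). A two-line illustration:

table = "lib_transactions"

# Adjacent literals (f-strings included) fuse into one string:
query = f"ALTER TABLE main.{table} " "ADD COLUMN message varchar"
print(query)  # ALTER TABLE main.lib_transactions ADD COLUMN message varchar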
31 changes: 11 additions & 20 deletions cumulus_library/statistics/psm.py
@@ -133,9 +133,7 @@ def _create_covariate_table(
"""Creates a covariate table from the loaded toml config"""
# checks for primary & link ref being the same
source_refs = list({self.config.primary_ref, self.config.count_ref} - {None})
pos_query = psm_templates.get_distinct_ids(
source_refs, self.config.pos_source_table
)
pos_query = psm_templates.get_distinct_ids(source_refs, self.config.pos_source_table)
pos = self._get_sampled_ids(
cursor,
schema,
@@ -187,8 +185,8 @@ def psm_plot_match(
         Title="Side by side matched controls",
         Ylabel="Number of patients",
         Xlabel="propensity logit",
-        names=["positive_cohort", "negative_cohort"],  # noqa: B006
-        colors=["#E69F00", "#56B4E9"],  # noqa: B006
+        names=None,
+        colors=None,
         save=True,
         filename="propensity_match.png",
     ):
@@ -199,6 +197,8 @@ def psm_plot_match(
         and passing in the psm object instead of assuming a call from inside
         the PsmPy class.
         """
+        names = names or ["positive_cohort", "negative_cohort"]
+        colors = colors or ["#E69F00", "#56B4E9"]
         dftreat = psm.df_matched[psm.df_matched[psm.treatment] == 1]
         dfcontrol = psm.df_matched[psm.df_matched[psm.treatment] == 0]
         x1 = dftreat[matched_entity]
@@ -217,7 +217,7 @@ def psm_effect_size_plot(
     def psm_effect_size_plot(
         self,
         psm,
-        title="Standardized Mean differences accross covariates before and after matching",  # noqa: E501
+        title="Standardized Mean differences accross covariates before and after matching",
         before_color="#FCB754",
         after_color="#3EC8FB",
         save=False,
@@ -239,9 +239,7 @@ def psm_effect_size_plot(
         for cl in psm.xvars:
             data.append([cl, "before", cohenD(df_preds_b4_float, psm.treatment, cl)])
             data.append([cl, "after", cohenD(df_preds_after_float, psm.treatment, cl)])
-        psm.effect_size = pandas.DataFrame(
-            data, columns=["Variable", "matching", "Effect Size"]
-        )
+        psm.effect_size = pandas.DataFrame(data, columns=["Variable", "matching", "Effect Size"])
         sns.set_style("white")
         sns_plot = sns.barplot(
             data=psm.effect_size,
@@ -260,17 +258,13 @@ def generate_psm_analysis(
     ):
         stats_table = f"{self.config.target_table}_{table_suffix}"
         """Runs PSM statistics on generated tables"""
-        cursor.execute(
-            base_templates.get_alias_table_query(stats_table, self.config.target_table)
-        )
+        cursor.execute(base_templates.get_alias_table_query(stats_table, self.config.target_table))
         df = cursor.execute(
             base_templates.get_select_all_query(self.config.target_table)
         ).as_pandas()
         symptoms_dict = self._get_symptoms_dict(self.config.classification_json)
         for dependent_variable, codes in symptoms_dict.items():
-            df[dependent_variable] = df["code"].apply(
-                lambda x, codes=codes: 1 if x in codes else 0
-            )
+            df[dependent_variable] = df["code"].apply(lambda x, codes=codes: 1 if x in codes else 0)
         df = df.drop(columns="code")
         # instance_count present but unused for PSM if table contains a count_ref input
         # (it's intended for manual review)
@@ -342,13 +336,10 @@ def generate_psm_analysis(
             )
         except ValueError:
             sys.exit(
-                "Encountered a value error during KNN matching. Try increasing "
-                "your sample size."
+                "Encountered a value error during KNN matching. Try increasing " "your sample size."
             )

-    def prepare_queries(
-        self, config: base_utils.StudyConfig, *args, table_suffix: str, **kwargs
-    ):
+    def prepare_queries(self, config: base_utils.StudyConfig, *args, table_suffix: str, **kwargs):
         self._create_covariate_table(config.db.cursor(), config.schema, table_suffix)

     def post_execution(
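The names/colors change above replaces suppressed B006 violations with the standard sentinel pattern: a mutable default is built once at function definition time and then shared by every call. A minimal sketch of the pitfall and the fix used here:

def bad(items=[]):  # B006: this single list is shared across calls
    items.append("x")
    return items

print(bad())  # ['x']
print(bad())  # ['x', 'x']  <- state leaked from the first call


def good(items=None):
    # Note "or" also replaces an explicitly passed empty list, which is
    # acceptable for display defaults like plot labels and colors.
    items = items or []
    items.append("x")
    return items

print(good())  # ['x']
print(good())  # ['x']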
@@ -21,7 +21,7 @@ def get_system_pairs(output_table_name: str, code_system_tables: list) -> str:
         for column in table["column_hierarchy"]:
             unnest_layer = ".".join(x for x in [unnest_layer, column[0]] if x)
             display_col = ".".join(x for x in [display_col, column[0]] if x)
-            if column[1] == list:
+            if column[1] is list:
                 squashed_hierarchy.append((unnest_layer, list))
                 unnest_layer = ""
         if unnest_layer != "":
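The == to is swap matches Ruff's preference (rule E721 covers this family of checks) for identity comparison against type objects: each class is a singleton, so "is" is exact and cannot be intercepted by a custom __eq__. A quick sketch:

column = ("code", list)  # (name, expected container type), as in the hunk above

# Identity is the unambiguous test for "is this exactly the list type?":
if column[1] is list:
    print("column unnests into a list")

# isinstance() is the tool for instances (and subclasses) rather than
# stored type objects, which is why it does not fit this call site:
assert isinstance(["a", "b"], list)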