
Commit

ruff format, noqa jinja autoescape
dogversioning committed Jul 29, 2024
1 parent b81967e commit 88eed26
Showing 45 changed files with 110 additions and 321 deletions.
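Context for the commit title, ahead of the per-file hunks: the "ruff format" half collapses previously wrapped calls onto single lines, which is what most of the hunks below show (consistent with a raised formatter line length, though the configuration change itself is not among the hunks captured here). The "noqa jinja autoescape" half most likely refers to suppressing ruff's bandit-derived autoescape warning on a jinja2 environment used for SQL templating; the sketch below is a hypothetical illustration only, and the rule code, loader path, and call site are assumptions rather than lines from this commit.

# Hypothetical sketch, not part of the diff: silencing the jinja2 autoescape
# check (ruff rule S701) for a template environment that renders SQL, not HTML.
from jinja2 import Environment, FileSystemLoader

# Rendering SQL text, so HTML autoescaping is intentionally left off.
env = Environment(loader=FileSystemLoader("template_sql"), autoescape=False)  # noqa: S701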
12 changes: 3 additions & 9 deletions cumulus_library/actions/importer.py
@@ -11,9 +7,7 @@

def _create_table_from_parquet(archive, file, study_name, config):
try:
- parquet_path = pathlib.Path(
- archive.extract(file), path=tempfile.TemporaryFile()
- )
+ parquet_path = pathlib.Path(archive.extract(file), path=tempfile.TemporaryFile())
# While convenient to access, this exposes us to panda's type system,
# which is messy - this could be optionally be replaced by pyarrow if it
# becomes problematic.
@@ -56,13 +54,9 @@ def import_archive(config: base_utils.StudyConfig, *, archive_path: pathlib.Path
files = archive.namelist()
files = [file for file in files if file.endswith(".parquet")]
except zipfile.BadZipFile as e:
- raise errors.StudyImportError(
- f"File {archive_path} is not a valid archive."
- ) from e
+ raise errors.StudyImportError(f"File {archive_path} is not a valid archive.") from e
if not any("__" in file for file in files):
- raise errors.StudyImportError(
- f"File {archive_path} contains non-study parquet files."
- )
+ raise errors.StudyImportError(f"File {archive_path} contains non-study parquet files.")
study_name = files[0].split("__")[0]
for file in files[1:]:
if file.split("__")[0] != study_name:
7 changes: 2 additions & 5 deletions cumulus_library/actions/uploader.py
@@ -65,8 +65,7 @@ def upload_files(args: dict):
"""Wrapper to prep files & console output"""
if args["data_path"] is None:
sys.exit(
"No data directory provided - please provide a path to your"
"study export folder."
"No data directory provided - please provide a path to your" "study export folder."
)
file_paths = list(args["data_path"].glob("**/*.parquet"))
if args["target"]:
@@ -79,9 +78,7 @@ def upload_files(args: dict):
if not args["user"] or not args["id"]:
sys.exit("user/id not provided, please pass --user and --id")
try:
- meta_version = next(
- filter(lambda x: str(x).endswith("__meta_version.parquet"), file_paths)
- )
+ meta_version = next(filter(lambda x: str(x).endswith("__meta_version.parquet"), file_paths))
version = str(read_parquet(meta_version)["data_package_version"][0])
file_paths.remove(meta_version)
except StopIteration:
4 changes: 1 addition & 3 deletions cumulus_library/base_utils.py
@@ -124,9 +121,7 @@ def zip_dir(read_path, write_path, archive_name):
shutil.rmtree(read_path)


- def update_query_if_schema_specified(
- query: str, manifest: study_manifest.StudyManifest
- ):
+ def update_query_if_schema_specified(query: str, manifest: study_manifest.StudyManifest):
if manifest and manifest.get_dedicated_schema():
# External queries in athena require a schema to be specified already, so
# rather than splitting and ending up with a table name like
36 changes: 9 additions & 27 deletions cumulus_library/cli.py
@@ -98,9 +98,7 @@ def clean_and_build_study(
"""
manifest = study_manifest.StudyManifest(target, self.data_path)
try:
- builder.run_protected_table_builder(
- config=self.get_config(manifest), manifest=manifest
- )
+ builder.run_protected_table_builder(config=self.get_config(manifest), manifest=manifest)
if not continue_from:
log_utils.log_transaction(
config=self.get_config(manifest),
@@ -111,9 +109,7 @@
config=self.get_config(manifest),
manifest=manifest,
)
- builder.run_table_builder(
- config=self.get_config(manifest), manifest=manifest
- )
+ builder.run_table_builder(config=self.get_config(manifest), manifest=manifest)

else:
log_utils.log_transaction(
@@ -127,9 +123,7 @@
manifest=manifest,
continue_from=continue_from,
)
- builder.run_counts_builders(
- config=self.get_config(manifest), manifest=manifest
- )
+ builder.run_counts_builders(config=self.get_config(manifest), manifest=manifest)
builder.run_statistics_builders(
config=self.get_config(manifest),
manifest=manifest,
@@ -170,9 +164,7 @@ def run_matching_table_builder(
)

### Data exporters
- def export_study(
- self, target: pathlib.Path, data_path: pathlib.Path, archive: bool
- ) -> None:
+ def export_study(self, target: pathlib.Path, data_path: pathlib.Path, archive: bool) -> None:
"""Exports aggregates defined in a manifest
:param target: A path to the study directory
@@ -199,9 +191,7 @@ def generate_study_sql(
:keyword builder: Specify a single builder to generate sql from
"""
manifest = study_manifest.StudyManifest(target)
- file_generator.run_generate_sql(
- config=self.get_config(manifest), manifest=manifest
- )
+ file_generator.run_generate_sql(config=self.get_config(manifest), manifest=manifest)

def generate_study_markdown(
self,
@@ -212,9 +202,7 @@ def generate_study_markdown(
:param target: A path to the study directory
"""
manifest = study_manifest.StudyManifest(target)
- file_generator.run_generate_markdown(
- config=self.get_config(manifest), manifest=manifest
- )
+ file_generator.run_generate_markdown(config=self.get_config(manifest), manifest=manifest)


def get_abs_path(path: str) -> pathlib.Path:
@@ -324,9 +312,7 @@ def run_cli(args: dict):
elif args["action"] == "build":
for target in args["target"]:
if args["builder"]:
- runner.run_matching_table_builder(
- study_dict[target], args["builder"]
- )
+ runner.run_matching_table_builder(study_dict[target], args["builder"])
else:
runner.clean_and_build_study(
study_dict[target],
@@ -350,9 +336,7 @@
if response.lower() != "y":
sys.exit()
for target in args["target"]:
- runner.export_study(
- study_dict[target], args["data_path"], args["archive"]
- )
+ runner.export_study(study_dict[target], args["data_path"], args["archive"])

elif args["action"] == "import":
for archive in args["archive_path"]:
@@ -361,9 +345,7 @@

elif args["action"] == "generate-sql":
for target in args["target"]:
- runner.generate_study_sql(
- study_dict[target], builder=args["builder"]
- )
+ runner.generate_study_sql(study_dict[target], builder=args["builder"])

elif args["action"] == "generate-md":
for target in args["target"]:
20 changes: 5 additions & 15 deletions cumulus_library/cli_parser.py
@@ -228,9 +228,7 @@ def get_parser() -> argparse.ArgumentParser:

# Database export

- export = actions.add_parser(
- "export", help="Generates files on disk from Athena tables/views"
- )
+ export = actions.add_parser("export", help="Generates files on disk from Athena tables/views")
add_custom_option(export)
add_target_argument(export)
add_study_dir_argument(export)
@@ -245,9 +243,7 @@ def get_parser() -> argparse.ArgumentParser:

# Database import

- importer = actions.add_parser(
- "import", help="Recreates a study from an exported archive"
- )
+ importer = actions.add_parser("import", help="Recreates a study from an exported archive")
add_db_config(importer)
add_verbose_argument(importer)
importer.add_argument(
@@ -258,15 +254,11 @@
)
# Aggregator upload

- upload = actions.add_parser(
- "upload", help="Bulk uploads data to Cumulus aggregator"
- )
+ upload = actions.add_parser("upload", help="Bulk uploads data to Cumulus aggregator")
add_data_path_argument(upload)
add_target_argument(upload)

- upload.add_argument(
- "--id", help="Site ID. Default is value of CUMULUS_AGGREGATOR_ID"
- )
+ upload.add_argument("--id", help="Site ID. Default is value of CUMULUS_AGGREGATOR_ID")
upload.add_argument(
"--preview",
default=False,
@@ -281,9 +273,7 @@ def get_parser() -> argparse.ArgumentParser:
),
default="https://aggregator.smartcumulus.org/upload/",
)
- upload.add_argument(
- "--user", help="Cumulus user. Default is value of CUMULUS_AGGREGATOR_USER"
- )
+ upload.add_argument("--user", help="Cumulus user. Default is value of CUMULUS_AGGREGATOR_USER")

# Generate a study's template-driven sql

17 changes: 4 additions & 13 deletions cumulus_library/databases/athena.py
@@ -76,10 +76,7 @@ def col_parquet_types_from_pandas(self, field_types: list) -> list:
match field:
case numpy.dtypes.ObjectDType():
output.append("STRING")
- case (
- pandas.core.arrays.integer.Int64Dtype()
- | numpy.dtypes.Int64DType()
- ):
+ case pandas.core.arrays.integer.Int64Dtype() | numpy.dtypes.Int64DType():
output.append("INT")
case numpy.dtypes.Float64DType():
output.append("DOUBLE")
@@ -114,9 +111,7 @@ def col_pyarrow_types_from_sql(self, columns: list[tuple]) -> list:
case "timestamp":
output.append((column[0], pyarrow.timestamp("s")))
case _:
- raise errors.CumulusLibraryError(
- f"Unsupported SQL type '{column[1]}' found."
- )
+ raise errors.CumulusLibraryError(f"Unsupported SQL type '{column[1]}' found.")
return output

def upload_file(
@@ -134,18 +129,14 @@ def upload_file(
s3_path = wg_conf["OutputLocation"]
bucket = "/".join(s3_path.split("/")[2:3])
key_prefix = "/".join(s3_path.split("/")[3:])
- encryption_type = wg_conf.get("EncryptionConfiguration", {}).get(
- "EncryptionOption", {}
- )
+ encryption_type = wg_conf.get("EncryptionConfiguration", {}).get("EncryptionOption", {})
if encryption_type != "SSE_KMS":
raise errors.AWSError(
f"Bucket {bucket} has unexpected encryption type {encryption_type}."
"AWS KMS encryption is expected for Cumulus buckets"
)
kms_arn = wg_conf.get("EncryptionConfiguration", {}).get("KmsKey", None)
- s3_key = (
- f"{key_prefix}cumulus_user_uploads/{self.schema_name}/" f"{study}/{topic}"
- )
+ s3_key = f"{key_prefix}cumulus_user_uploads/{self.schema_name}/" f"{study}/{topic}"
if not remote_filename:
remote_filename = file.name

4 changes: 1 addition & 3 deletions cumulus_library/databases/base.py
@@ -92,9 +92,7 @@ def _recursively_validate(

return output

- def validate_table_schema(
- self, expected: dict[str, list], schema: list[tuple]
- ) -> dict:
+ def validate_table_schema(self, expected: dict[str, list], schema: list[tuple]) -> dict:
"""Public interface for investigating if fields are in a table schema.
expected is a dictionary of string column names to *something*:
4 changes: 1 addition & 3 deletions cumulus_library/databases/duckdb.py
@@ -94,9 +94,7 @@ def insert_tables(self, tables: dict[str, pyarrow.Table]) -> None:
self.connection.register(name, table)

@staticmethod
- def _compat_array_join(
- value: list[str | None] | None, delimiter: str | None
- ) -> str | None:
+ def _compat_array_join(value: list[str | None] | None, delimiter: str | None) -> str | None:
if value is None:
return None
if delimiter is None or delimiter == "None":
10 changes: 2 additions & 8 deletions cumulus_library/databases/utils.py
@@ -61,11 +61,7 @@ def read_ndjson_dir(path: str) -> dict[str, pyarrow.Table]:
"etl__completion_encounters",
]
for metadata_table in metadata_tables:
- rows = list(
- cumulus_fhir_support.read_multiline_json_from_dir(
- f"{path}/{metadata_table}"
- )
- )
+ rows = list(cumulus_fhir_support.read_multiline_json_from_dir(f"{path}/{metadata_table}"))
if rows:
# Auto-detecting the schema works for these simple tables
all_tables[metadata_table] = pyarrow.Table.from_pylist(rows)
@@ -113,8 +109,6 @@ def create_db_backend(args: dict[str, str]) -> (base.DatabaseBackend, str):
if args.get("load_ndjson_dir"):
sys.exit("Loading an ndjson dir is not supported with --db-type=athena.")
else:
- raise errors.CumulusLibraryError(
- f"'{db_config.db_type}' is not a supported database."
- )
+ raise errors.CumulusLibraryError(f"'{db_config.db_type}' is not a supported database.")

return (backend, schema_name)
6 changes: 2 additions & 4 deletions cumulus_library/protected_table_builder.py
@@ -51,12 +51,10 @@ def prepare_queries(
else:
db_schema = config.schema
transactions = (
f"{manifest.get_study_prefix()}"
f"__{enums.ProtectedTables.TRANSACTIONS.value}"
f"{manifest.get_study_prefix()}" f"__{enums.ProtectedTables.TRANSACTIONS.value}"
)
statistics = (
f"{manifest.get_study_prefix()}"
f"__{enums.ProtectedTables.STATISTICS.value}"
f"{manifest.get_study_prefix()}" f"__{enums.ProtectedTables.STATISTICS.value}"
)
self.queries.append(
base_templates.get_ctas_empty_query(
12 changes: 3 additions & 9 deletions cumulus_library/schema/typesystem.py
@@ -16,16 +16,10 @@ class Structure(FHIR):

Patient = "http://hl7.org/fhir/us/core/StructureDefinition/us-core-patient"
Encounter = "http://hl7.org/fhir/us/core/StructureDefinition/us-core-encounter"
- DocumentReference = (
- "http://hl7.org/fhir/us/core/StructureDefinition/us-core-documentreference"
- )
+ DocumentReference = "http://hl7.org/fhir/us/core/StructureDefinition/us-core-documentreference"
Condition = "http://hl7.org/fhir/condition-definitions.html"
- ObservationLab = (
- "http://hl7.org/fhir/us/core/StructureDefinition/us-core-observation-lab"
- )
- ObservationValue = (
- "http://hl7.org/fhir/observation-definitions.html#Observation.value_x_"
- )
+ ObservationLab = "http://hl7.org/fhir/us/core/StructureDefinition/us-core-observation-lab"
+ ObservationValue = "http://hl7.org/fhir/observation-definitions.html#Observation.value_x_"
VitalSign = "http://hl7.org/fhir/us/vitals/ImplementationGuide/hl7.fhir.us.vitals"


4 changes: 1 addition & 3 deletions cumulus_library/schema/valueset.py
@@ -26,9 +26,7 @@ class ValueSet(Enum):
DocumentType = "http://hl7.org/fhir/ValueSet/c80-doc-typecodes"
ObservationCode = "http://hl7.org/fhir/ValueSet/observation-codes"
ObservationCategory = "http://hl7.org/fhir/ValueSet/observation-category"
- ObservationInterpretation = (
- "http://hl7.org/fhir/ValueSet/observation-interpretation"
- )
+ ObservationInterpretation = "http://hl7.org/fhir/ValueSet/observation-interpretation"

def __init__(self, url: str):
self.url = url
16 changes: 4 additions & 12 deletions cumulus_library/statistics/counts.py
@@ -39,9 +39,7 @@ def get_table_name(self, table_name: str, duration=None) -> str:
else:
return f"{self.study_prefix}__{table_name}"

- def get_where_clauses(
- self, clause: list | str | None = None, min_subject: int = 10
- ) -> str:
+ def get_where_clauses(self, clause: list | str | None = None, min_subject: int = 10) -> str:
"""Convenience method for constructing arbitrary where clauses.
:param clause: either a string or a list of sql where statements
@@ -55,9 +53,7 @@ def get_where_clauses(
elif isinstance(clause, list):
return clause
else:
- raise errors.CountsBuilderError(
- f"get_where_clauses invalid clause {clause}"
- )
+ raise errors.CountsBuilderError(f"get_where_clauses invalid clause {clause}")

def get_count_query(
self, table_name: str, source_table: str, table_cols: list, **kwargs
@@ -84,12 +80,8 @@ def get_count_query(
"fhir_resource",
"filter_resource",
]:
- raise errors.CountsBuilderError(
- f"count_query received unexpected key: {key}"
- )
- return counts_templates.get_count_query(
- table_name, source_table, table_cols, **kwargs
- )
+ raise errors.CountsBuilderError(f"count_query received unexpected key: {key}")
+ return counts_templates.get_count_query(table_name, source_table, table_cols, **kwargs)

# ----------------------------------------------------------------------
# The following function all wrap get_count_query as convenience methods.
(additional changed file; file name not captured in this view)
@@ -52,13 +52,9 @@
for item in table_cols:
# TODO: remove check after cutover
if isinstance(item, list):
- table_col_classed.append(
- CountColumn(name=item[0], db_type=item[1], alias=item[2])
- )
+ table_col_classed.append(CountColumn(name=item[0], db_type=item[1], alias=item[2]))
else:
- table_col_classed.append(
- CountColumn(name=item, db_type="varchar", alias=None)
- )
+ table_col_classed.append(CountColumn(name=item, db_type="varchar", alias=None))
table_cols = table_col_classed

query = base_templates.get_base_template(
(Diffs for the remaining changed files in this commit are not shown here.)