diff --git a/pyproject.toml b/pyproject.toml
index 489fc06..757633b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
     "duckdb",
     # for debug logging (referenced from the issue template)
     "session-info",
+    "xarray",
 ]
 optional-dependencies.dev = [
     "pre-commit",
diff --git a/src/ehrdata/dt/datasets.py b/src/ehrdata/dt/datasets.py
index adc50bf..33545be 100644
--- a/src/ehrdata/dt/datasets.py
+++ b/src/ehrdata/dt/datasets.py
@@ -17,23 +17,36 @@ def _get_table_list() -> list:
     return flat_table_list
 
 
-def _set_up_duckdb(path: Path, backend_handle: DuckDBPyConnection) -> None:
+def _set_up_duckdb(path: Path, backend_handle: DuckDBPyConnection, prefix: str = "") -> None:
     tables = _get_table_list()
 
+    used_tables = []
     missing_tables = []
-    for table in tables:
-        # if path exists lowercse, uppercase, capitalized:
-        table_path = f"{path}/{table}.csv"
-        if os.path.exists(table_path):
-            if table == "measurement":
-                backend_handle.register(
-                    table, backend_handle.read_csv(f"{path}/{table}.csv", dtype={"measurement_source_value": str})
-                )
+    unused_files = []
+    for file_name in os.listdir(path):
+        file_name_trunk = file_name.split(".")[0].lower()
+
+        if file_name_trunk in tables or file_name_trunk.replace(prefix, "") in tables:
+            used_tables.append(file_name_trunk.replace(prefix, ""))
+
+            if file_name_trunk.replace(prefix, "") == "measurement":
+                dtype = {"measurement_source_value": str}
             else:
-                backend_handle.register(table, backend_handle.read_csv(f"{path}/{table}.csv"))
+                dtype = None
+
+            backend_handle.register(
+                file_name_trunk.replace(prefix, ""),
+                backend_handle.read_csv(f"{path}/{file_name}", dtype=dtype),
+            )
         else:
-            missing_tables.append([table])
+            unused_files.append(file_name)
+
+    for table in tables:
+        if table not in used_tables:
+            missing_tables.append(table)
+
     print("missing tables: ", missing_tables)
+    print("unused files: ", unused_files)
 
 
 def mimic_iv_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
@@ -80,14 +93,14 @@ def mimic_iv_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = N
     else:
         print(f"Failed to download the file. Status code: {response.status_code}")
         return
-
-    return _set_up_duckdb(data_path + "/1_omop_data_csv", backend_handle)
+    # TODO: handle further file name variants (capitalization, other prefixes)
+    return _set_up_duckdb(data_path / "1_omop_data_csv", backend_handle, prefix="2b_")
 
 
 def gibleed_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
-    """Loads the GIBleed dataset.
+    """Loads the GIBleed dataset in the OMOP Common Data Model.
 
-    More details: https://github.com/OHDSI/EunomiaDatasets.
+    More details: https://github.com/OHDSI/EunomiaDatasets/tree/main/datasets/GiBleed.
 
     Parameters
     ----------
@@ -109,13 +122,38 @@ def gibleed_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = No
     >>> ed.dt.gibleed_omop(backend_handle=con)
     >>> con.execute("SHOW TABLES;").fetchall()
     """
-    # TODO:
-    # https://github.com/darwin-eu/EunomiaDatasets/tree/main/datasets/GiBleed
-    raise NotImplementedError()
+    if data_path is None:
+        data_path = Path("ehrapy_data/GIBleed_dataset")
+
+    if data_path.exists():
+        print(f"Path to data exists, load tables from there: {data_path}")
+    else:
+        print("Downloading data...")
+        URL = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/GiBleed/GiBleed_5.3.zip"
+        response = requests.get(URL)
+
+        if response.status_code == 200:
+            # Open the ZIP file in memory and extract its contents into data_path
+            with zipfile.ZipFile(io.BytesIO(response.content)) as z:
+                z.extractall(data_path)
+            print(f"Download successful. ZIP file downloaded and extracted successfully to {data_path}.")
+
+        else:
+            print(f"Failed to download the file. Status code: {response.status_code}")
+            return
+
+    return _set_up_duckdb(data_path / "GiBleed_5.3", backend_handle)
 
 
 def synthea27nj_omop(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None:
-    """Loads the Synthe27Nj dataset.
+    """Loads the Synthea27NJ dataset in the OMOP Common Data Model.
+
+    More details: https://github.com/darwin-eu/EunomiaDatasets/tree/main/datasets/Synthea27Nj.
 
     Parameters
     ----------
@@ -137,9 +175,39 @@ def synthea27nj_omop(backend_handle: DuckDBPyConnection, data_path: Path | None
     >>> ed.dt.synthea27nj_omop(backend_handle=con)
     >>> con.execute("SHOW TABLES;").fetchall()
     """
-    # TODO
-    # https://github.com/darwin-eu/EunomiaDatasets/tree/main/datasets/Synthea27Nj
-    raise NotImplementedError()
+    if data_path is None:
+        data_path = Path("ehrapy_data/Synthea27Nj")
+
+    if data_path.exists():
+        print(f"Path to data exists, load tables from there: {data_path}")
+    else:
+        print("Downloading data...")
+        URL = "https://github.com/OHDSI/EunomiaDatasets/raw/main/datasets/Synthea27Nj/Synthea27Nj_5.4.zip"
+        response = requests.get(URL)
+
+        if response.status_code == 200:
+            extract_path = data_path / "synthea27nj_omop_csv"
+            extract_path.mkdir(parents=True, exist_ok=True)
+
+            # Open the ZIP file in memory and extract its contents into extract_path
+            with zipfile.ZipFile(io.BytesIO(response.content)) as z:
+                z.extractall(extract_path)
+            print(f"Download successful. ZIP file downloaded and extracted successfully to {extract_path}.")
+
+        else:
+            print(f"Failed to download the file. 
Status code: {response.status_code}") + return + + extracted_folder = next( + ( + folder + for folder in data_path.iterdir() + if folder.is_dir() and "_csv" in folder.name and "__MACOSX" not in folder.name + ), + data_path, + ) + return _set_up_duckdb(extracted_folder, backend_handle) def mimic_ii(backend_handle: DuckDBPyConnection, data_path: Path | None = None) -> None: diff --git a/src/ehrdata/io/omop/__init__.py b/src/ehrdata/io/omop/__init__.py index eb3908b..8cd4668 100644 --- a/src/ehrdata/io/omop/__init__.py +++ b/src/ehrdata/io/omop/__init__.py @@ -1,17 +1,19 @@ from .omop import ( - extract_condition_occurrence, - extract_device_exposure, - extract_drug_exposure, - extract_measurement, - extract_note, - extract_observation, - extract_observation_period, - extract_person, - extract_person_observation_period, - extract_procedure_occurrence, - extract_specimen, + get_table, get_time_interval_table, load, + # extract_condition_occurrence, + # extract_device_exposure, + # extract_drug_exposure, + # extract_measurement, + # extract_note, + # extract_observation, + # extract_observation_period, + # extract_person, + # extract_person_observation_period, + # extract_procedure_occurrence, + # extract_specimen, + register_omop_to_db_connection, setup_obs, setup_variables, ) diff --git a/src/ehrdata/io/omop/_queries.py b/src/ehrdata/io/omop/_queries.py new file mode 100644 index 0000000..abdbf80 --- /dev/null +++ b/src/ehrdata/io/omop/_queries.py @@ -0,0 +1,139 @@ +from collections.abc import Sequence + +import duckdb +import pandas as pd + +START_DATE_KEY = { + "visit_occurrence": "visit_start_date", + "observation_period": "observation_period_start_date", + "cohort": "cohort_start_date", +} +END_DATE_KEY = { + "visit_occurrence": "visit_end_date", + "observation_period": "observation_period_end_date", + "cohort": "cohort_end_date", +} +TIME_DEFINING_TABLE_SUBJECT_KEY = { + "visit_occurrence": "person_id", + "observation_period": "person_id", + "cohort": "subject_id", +} + +AGGREGATION_STRATEGY_KEY = { + "last": "LAST", + "first": "FIRST", + "mean": "MEAN", + "median": "MEDIAN", + "mode": "MODE", + "sum": "SUM", + "count": "COUNT", + "min": "MIN", + "max": "MAX", + "std": "STD", +} + + +def _generate_timedeltas(interval_length_number: int, interval_length_unit: str, num_intervals: int) -> pd.DataFrame: + timedeltas_dataframe = pd.DataFrame( + { + "interval_start_offset": [ + pd.to_timedelta(i * interval_length_number, interval_length_unit) for i in range(num_intervals) + ], + "interval_end_offset": [ + pd.to_timedelta(i * interval_length_number, interval_length_unit) for i in range(1, num_intervals + 1) + ], + "interval_step": list(range(num_intervals)), + } + ) + return timedeltas_dataframe + + +def _write_timedeltas_to_db( + backend_handle: duckdb.duckdb.DuckDBPyConnection, + timedeltas_dataframe, +) -> None: + backend_handle.execute("DROP TABLE IF EXISTS timedeltas") + backend_handle.execute( + """ + CREATE TABLE timedeltas ( + interval_start_offset INTERVAL, + interval_end_offset INTERVAL, + interval_step INTEGER + ) + """ + ) + backend_handle.execute("INSERT INTO timedeltas SELECT * FROM timedeltas_dataframe") + + +def _drop_timedeltas(backend_handle: duckdb.duckdb.DuckDBPyConnection): + backend_handle.execute("DROP TABLE IF EXISTS timedeltas") + + +def _generate_value_query(data_table: str, data_field_to_keep: Sequence, aggregation_strategy: str) -> str: + query = f"{', ' .join([f'CASE WHEN COUNT(*) = 0 THEN NULL ELSE {aggregation_strategy}({column}) END AS {column}' for column in 
data_field_to_keep])}" + return query + + +def time_interval_table_query_long_format( + backend_handle: duckdb.duckdb.DuckDBPyConnection, + time_defining_table: str, + data_table: str, + interval_length_number: int, + interval_length_unit: str, + num_intervals: int, + aggregation_strategy: str, + data_field_to_keep: Sequence[str] | str, +) -> pd.DataFrame: + """Returns a long format DataFrame from the data_table. The following columns should be considered the indices of this long format: person_id, data_table_concept_id, interval_step. The other columns, except for start_date and end_date, should be considered the values.""" + if isinstance(data_field_to_keep, str): + data_field_to_keep = [data_field_to_keep] + + timedeltas_dataframe = _generate_timedeltas(interval_length_number, interval_length_unit, num_intervals) + + _write_timedeltas_to_db( + backend_handle, + timedeltas_dataframe, + ) + + # multi-step query + # 1. Create person_time_defining_table, which matches the one created for obs. Needs to contain the person_id, and the start date in particular. + # 2. Create person_data_table (data_table is typically measurement), which contains the cross product of person_id and the distinct concept_id s. + # 3. Create long_format_backbone, which is the left join of person_time_defining_table and person_data_table. + # 4. Create long_format_intervals, which is the cross product of long_format_backbone and timedeltas. This table contains most notably the person_id, the concept_id, the interval start and end dates. + # 5. Create the final table, which is the join with the data_table (typically measurement); each measurement is assigned to its person_id, its concept_id, and the interval it fits into. + df = backend_handle.execute( + f""" + WITH person_time_defining_table AS ( \ + SELECT person.person_id as person_id, {START_DATE_KEY[time_defining_table]} as start_date, {END_DATE_KEY[time_defining_table]} as end_date \ + FROM person \ + JOIN {time_defining_table} ON person.person_id = {time_defining_table}.{TIME_DEFINING_TABLE_SUBJECT_KEY[time_defining_table]} \ + ), \ + person_data_table AS( \ + WITH distinct_data_table_concept_ids AS ( \ + SELECT DISTINCT {data_table}_concept_id + FROM {data_table} \ + ) + SELECT person.person_id, {data_table}_concept_id as data_table_concept_id \ + FROM person \ + CROSS JOIN distinct_data_table_concept_ids \ + ), \ + long_format_backbone as ( \ + SELECT person_time_defining_table.person_id, data_table_concept_id, start_date, end_date \ + FROM person_time_defining_table \ + LEFT JOIN person_data_table USING(person_id)\ + ), \ + long_format_intervals as ( \ + SELECT person_id, data_table_concept_id, interval_step, start_date, start_date + interval_start_offset as interval_start, start_date + interval_end_offset as interval_end \ + FROM long_format_backbone \ + CROSS JOIN timedeltas \ + ) \ + SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query(data_table, data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \ + FROM long_format_intervals as lfi \ + LEFT JOIN {data_table} ON lfi.person_id = {data_table}.person_id AND lfi.data_table_concept_id = {data_table}.{data_table}_concept_id AND {data_table}.{data_table}_date BETWEEN lfi.interval_start AND lfi.interval_end \ + GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end + """ + ).df() + + _drop_timedeltas(backend_handle) + + return df diff --git a/src/ehrdata/io/omop/omop.py 
b/src/ehrdata/io/omop/omop.py index 844a107..6034b17 100644 --- a/src/ehrdata/io/omop/omop.py +++ b/src/ehrdata/io/omop/omop.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from collections.abc import Sequence from pathlib import Path from typing import Literal @@ -9,6 +10,16 @@ import numpy as np import pandas as pd +from ehrdata.io.omop._queries import ( + AGGREGATION_STRATEGY_KEY, + time_interval_table_query_long_format, +) +from ehrdata.utils._omop_utils import get_omop_table_names + +VALID_OBSERVATION_TABLES_SINGLE = ["person"] +VALID_OBSERVATION_TABLES_JOIN = ["person_cohort", "person_observation_period", "person_visit_occurrence"] +VALID_VARIABLE_TABLES = ["measurement", "observation", "specimen"] + def _check_sanity_of_folder(folder_path: str | Path): pass @@ -18,15 +29,195 @@ def _check_sanity_of_database(backend_handle: duckdb.DuckDB): pass +def _check_valid_backend_handle(backend_handle) -> None: + if not isinstance(backend_handle, duckdb.duckdb.DuckDBPyConnection): + raise TypeError("Expected backend_handle to be of type DuckDBPyConnection.") + + +def _check_valid_observation_table(observation_table) -> None: + if not isinstance(observation_table, str): + raise TypeError("Expected observation_table to be a string.") + if observation_table not in VALID_OBSERVATION_TABLES_SINGLE + VALID_OBSERVATION_TABLES_JOIN: + raise ValueError( + f"observation_table must be one of {VALID_OBSERVATION_TABLES_SINGLE+VALID_OBSERVATION_TABLES_JOIN}." + ) + + +def _check_valid_death_table(death_table) -> None: + if not isinstance(death_table, bool): + raise TypeError("Expected death_table to be a boolean.") + + +def _check_valid_edata(edata) -> None: + from ehrdata import EHRData + + if not isinstance(edata, EHRData): + raise TypeError("Expected edata to be of type EHRData.") + + +def _check_valid_data_tables(data_tables) -> Sequence: + if isinstance(data_tables, str): + data_tables = [data_tables] + if not isinstance(data_tables, Sequence): + raise TypeError("Expected data_tables to be a string or Sequence.") + if not all(table in VALID_VARIABLE_TABLES for table in data_tables): + raise ValueError(f"data_tables must be a subset of {VALID_VARIABLE_TABLES}.") + return data_tables + + +def _check_valid_data_field_to_keep(data_field_to_keep) -> Sequence: + if isinstance(data_field_to_keep, str): + data_field_to_keep = [data_field_to_keep] + if not isinstance(data_field_to_keep, Sequence): + raise TypeError("Expected data_field_to_keep to be a string or Sequence.") + return data_field_to_keep + + +def _check_valid_interval_length_number(interval_length_number) -> None: + if not isinstance(interval_length_number, int): + raise TypeError("Expected interval_length_number to be an integer.") + + +def _check_valid_interval_length_unit(interval_length_unit) -> None: + # TODO: maybe check if it is a valid unit from pandas.to_timedelta + if not isinstance(interval_length_unit, str): + raise TypeError("Expected interval_length_unit to be a string.") + + +def _check_valid_num_intervals(num_intervals) -> None: + if not isinstance(num_intervals, int): + raise TypeError("Expected num_intervals to be an integer.") + + +def _check_valid_concept_ids(concept_ids) -> None: + if concept_ids != "all" and not isinstance(concept_ids, Sequence): + raise TypeError("concept_ids must be a sequence of integers or 'all'.") + + +def _check_valid_aggregation_strategy(aggregation_strategy) -> None: + if aggregation_strategy not in AGGREGATION_STRATEGY_KEY.keys(): + raise TypeError(f"aggregation_strategy must be one 
of {AGGREGATION_STRATEGY_KEY.keys()}.") + + +def _check_valid_enrich_var_with_feature_info(enrich_var_with_feature_info) -> None: + if not isinstance(enrich_var_with_feature_info, bool): + raise TypeError("Expected enrich_var_with_feature_info to be a boolean.") + + +def _check_valid_enrich_var_with_unit_info(enrich_var_with_unit_info) -> None: + if not isinstance(enrich_var_with_unit_info, bool): + raise TypeError("Expected enrich_var_with_unit_info to be a boolean.") + + +def _collect_units_per_feature(ds, unit_key="unit_concept_id") -> dict: + feature_units = {} + for i in range(ds[unit_key].shape[1]): + single_feature_units = ds[unit_key].isel({ds[unit_key].dims[1]: i}) + single_feature_units_flat = np.array(single_feature_units).flatten() + single_feature_units_unique = pd.unique(single_feature_units_flat[~pd.isna(single_feature_units_flat)]) + feature_units[ds["data_table_concept_id"][i].item()] = single_feature_units_unique + return feature_units + + +def _check_one_unit_per_feature(ds, unit_key="unit_concept_id") -> None: + feature_units = _collect_units_per_feature(ds, unit_key=unit_key) + num_units = np.array([len(units) for _, units in feature_units.items()]) + + # print(f"no units for features: {np.argwhere(num_units == 0)}") + print(f"multiple units for features: {np.argwhere(num_units > 1)}") + + +def _create_feature_unit_concept_id_report(backend_handle, ds) -> pd.DataFrame: + feature_units_concept = _collect_units_per_feature(ds, unit_key="unit_concept_id") + + feature_units_long_format = [] + for feature, units in feature_units_concept.items(): + if len(units) == 0: + feature_units_long_format.append({"concept_id": feature, "no_units": True, "multiple_units": False}) + elif len(units) > 1: + for unit in units: + feature_units_long_format.append( + { + "concept_id": feature, + "unit_concept_id": unit, + "no_units": False, + "multiple_units": True, + } + ) + else: + feature_units_long_format.append( + { + "concept_id": feature, + "unit_concept_id": units[0], + "no_units": False, + "multiple_units": False, + } + ) + + df = pd.DataFrame( + feature_units_long_format, columns=["concept_id", "unit_concept_id", "no_units", "multiple_units"] + ) + df["unit_concept_id"] = df["unit_concept_id"].astype("Int64") + + return df + + +def _create_enriched_var_with_unit_info(backend_handle, ds, var, unit_report) -> pd.DataFrame: + feature_concept_id_table = var # ds["data_table_concept_id"].to_dataframe() + + feature_concept_id_unit_table = pd.merge( + feature_concept_id_table, unit_report, how="left", left_index=True, right_on="concept_id" + ) + + concepts = backend_handle.sql("SELECT * FROM concept").df() + + feature_concept_id_unit_info_table = pd.merge( + feature_concept_id_unit_table, + concepts, + how="left", + left_on="unit_concept_id", + right_on="concept_id", + ) + + return feature_concept_id_unit_info_table + + +def register_omop_to_db_connection( + path: Path, + backend_handle: duckdb.duckdb.DuckDBPyConnection, + source: Literal["csv"] = "csv", +) -> None: + """Register the OMOP CDM tables to the database.""" + missing_tables = [] + for table in get_omop_table_names(): + # if path exists lowercse, uppercase, capitalized: + table_path = f"{path}/{table}.csv" + if os.path.exists(table_path): + if table == "measurement": + backend_handle.register( + table, backend_handle.read_csv(f"{path}/{table}.csv", dtype={"measurement_source_value": str}) + ) + else: + backend_handle.register(table, backend_handle.read_csv(f"{path}/{table}.csv")) + else: + missing_tables.append([table]) + 
print("missing tables: ", missing_tables) + + return None + + def setup_obs( backend_handle: Literal[str, duckdb, Path], - observation_table: Literal["person", "observation_period", "person_observation_period", "condition_occurrence"], + observation_table: Literal["person", "person_cohort", "person_observation_period", "person_visit_occurrence"], + death_table: bool = False, ): """Setup the observation table. - This function sets up the observation table for the EHRData project. - For this, a table from the OMOP CDM which represents to observed unit should be selected. - A unit can be a person, an observation period, the join of these two tables, or a condition occurrence. + This function sets up the observation table for the EHRData object. + For this, a table from the OMOP CDM which represents the "observed unit" via an id should be selected. + A unit can be a person, or the data of a person together with either the information from cohort, observation_period, or visit_occurrence. + Notice a single person can have multiple of the latter, and as such can appear multiple times. + For person_cohort, the subject_id of the cohort is considered to be the person_id for a join. Parameters ---------- @@ -34,47 +225,57 @@ def setup_obs( The backend handle to the database. observation_table The observation table to be used. + death_table + Whether to include the death table. The observation_table created will be left joined with the death table as the right table. Returns ------- An EHRData object with populated .obs field. """ + _check_valid_backend_handle(backend_handle) + _check_valid_observation_table(observation_table) + _check_valid_death_table(death_table) + from ehrdata import EHRData - if observation_table == "person": - obs = extract_person(backend_handle) - elif observation_table == "observation_period": - obs = extract_observation_period(backend_handle) - elif observation_table == "person_observation_period": - obs = extract_person_observation_period(backend_handle) - elif observation_table == "condition_occurrence": - obs = extract_condition_occurrence(backend_handle) - else: - raise ValueError("observation_table must be either 'person', 'observation_period', or 'condition_occurrence'.") + if observation_table in VALID_OBSERVATION_TABLES_SINGLE: + obs = get_table(backend_handle, observation_table) + + elif observation_table in VALID_OBSERVATION_TABLES_JOIN: + if observation_table == "person_cohort": + obs = _get_table_join(backend_handle, "person", "cohort", right_key="subject_id") + elif observation_table == "person_observation_period": + obs = _get_table_join(backend_handle, "person", "observation_period") + elif observation_table == "person_visit_occurrence": + obs = _get_table_join(backend_handle, "person", "visit_occurrence") + + if death_table: + death = get_table(backend_handle, "death") + obs = obs.merge(death, how="left", on="person_id") - return EHRData(obs=obs) + return EHRData(obs=obs, uns={"omop_io_observation_table": observation_table.split("person_")[-1]}) def setup_variables( - backend_handle: Literal[str, duckdb, Path], edata, - tables: Sequence[ - Literal[ - "measurement", "observation", "procedure_occurrence", "specimen", "device_exposure", "drug_exposure", "note" - ] - ], - start_time: Literal["observation_period_start"] | pd.Timestamp | str, + *, + backend_handle: duckdb.duckdb.DuckDBPyConnection, + data_tables: Sequence[Literal["measurement", "observation", "specimen"]] + | Literal["measurement", "observation", "specimen"], + data_field_to_keep: str | 
Sequence[str],
     interval_length_number: int,
     interval_length_unit: str,
     num_intervals: int,
     concept_ids: Literal["all"] | Sequence = "all",
     aggregation_strategy: str = "last",
+    enrich_var_with_feature_info: bool = False,
+    enrich_var_with_unit_info: bool = False,
 ):
     """Setup the variables.
 
-    This function sets up the variables for the EHRData project.
-    For this, a selection of tables from the OMOP CDM which represents the variables should be selected.
-    The tables can be measurement, observation, procedure_occurrence, specimen, device_exposure, drug_exposure, or note.
+    This function sets up the variables for the EHRData object.
+    It will fail if enrich_var_with_unit_info is True and more than one unit_concept_id is found for a feature.
+    Writes a unit report of the features to edata.uns["unit_report_<data_table>"].
 
     Parameters
     ----------
@@ -82,70 +283,112 @@ def setup_variables(
         The backend handle to the database.
     edata
         The EHRData object to which the variables should be added.
-    tables
-        The tables to be used.
+    data_tables
+        The data table to be used. Currently, only a single table is supported.
+    data_field_to_keep
+        The CDM field in the data table to be kept. Can be e.g. "value_as_number" or "value_as_concept_id".
-    start_time
-        Starting time for values to be included. Can be 'observation_period' start, which takes the 'observation_period_start' value from obs, or a specific Timestamp.
     interval_length_number
         Numeric value of the length of one interval.
     interval_length_unit
-        Unit belonging to the interval length. See the units of `pandas.to_timedelta <https://pandas.pydata.org/docs/reference/api/pandas.to_timedelta.html>`_
+        Unit belonging to the interval length. Must be a unit supported by pandas.to_timedelta.
     num_intervals
-        Numer of intervals
+        Number of intervals.
+    concept_ids
+        Concept IDs to use from this data table. If not specified, 'all' are used.
+    aggregation_strategy
+        Strategy to use when aggregating multiple data points within one interval.
+    enrich_var_with_feature_info
+        Whether to enrich the var table with feature information. If a concept_id is not found in the concept table, the feature information will be NaN.
+    enrich_var_with_unit_info
+        Whether to enrich the var table with unit information. Raises an error if multiple units are found for at least one feature. If a concept_id is not found in the concept table, the unit information will be NaN.
 
     Returns
     -------
-    An EHRData object with populated .var field.
+    An EHRData object with populated .r and .var fields.
""" from ehrdata import EHRData - concept_ids_present_list = [] - time_interval_tables = [] - for table in tables: - if table == "measurement": - concept_ids_present = ( - backend_handle.sql("SELECT * FROM measurement").df()["measurement_concept_id"].unique() - ) - extracted_awkward = extract_measurement(backend_handle) - time_interval_table = get_time_interval_table( - backend_handle, - extracted_awkward, - edata.obs, - start_time="observation_period_start", - interval_length_number=interval_length_number, - interval_length_unit=interval_length_unit, - num_intervals=num_intervals, - concept_ids=concept_ids, - aggregation_strategy=aggregation_strategy, - ) - # TODO: implement the following - # elif table == "observation": - # var = extract_observation(backend_handle) - # elif table == "procedure_occurrence": - # var = extract_procedure_occurrence(backend_handle) - # elif table == "specimen": - # var = extract_specimen(backend_handle) - # elif table == "device_exposure": - # var = extract_device_exposure(backend_handle) - # elif table == "drug_exposure": - # var = extract_drug_exposure(backend_handle) - # elif table == "note": - # var = extract_note(backend_handle) + _check_valid_edata(edata) + _check_valid_backend_handle(backend_handle) + data_tables = _check_valid_data_tables(data_tables) + data_field_to_keep = _check_valid_data_field_to_keep(data_field_to_keep) + _check_valid_interval_length_number(interval_length_number) + _check_valid_interval_length_unit(interval_length_unit) + _check_valid_num_intervals(num_intervals) + _check_valid_concept_ids(concept_ids) + _check_valid_aggregation_strategy(aggregation_strategy) + _check_valid_enrich_var_with_feature_info(enrich_var_with_feature_info) + _check_valid_enrich_var_with_unit_info(enrich_var_with_unit_info) + + time_defining_table = edata.uns.get("omop_io_observation_table", None) + if time_defining_table is None: + raise ValueError("The observation table must be set up first, use the `setup_obs` function.") + + if data_tables[0] in ["measurement", "observation"]: + # also keep unit_concept_id and unit_source_value; + if isinstance(data_field_to_keep, list): + data_field_to_keep = list(data_field_to_keep) + ["unit_concept_id", "unit_source_value"] + # TODO: use in future version when more than one data table can be used + # elif isinstance(data_field_to_keep, dict): + # data_field_to_keep = { + # k: v + ["unit_concept_id", "unit_source_value"] for k, v in data_field_to_keep.items() + # } else: - raise ValueError( - "tables must be a sequence of 'measurement', 'observation', 'procedure_occurrence', 'specimen', 'device_exposure', 'drug_exposure', or 'note'." + raise ValueError + + ds = ( + time_interval_table_query_long_format( + backend_handle=backend_handle, + time_defining_table=time_defining_table, + data_table=data_tables[0], + data_field_to_keep=data_field_to_keep, + interval_length_number=interval_length_number, + interval_length_unit=interval_length_unit, + num_intervals=num_intervals, + aggregation_strategy=aggregation_strategy, + ) + .set_index(["person_id", "data_table_concept_id", "interval_step"]) + .to_xarray() + ) + + _check_one_unit_per_feature(ds) + # TODO ignore? go with more vanilla omop style. 
+    # TODO ignore? or go with more vanilla OMOP style: _check_one_unit_per_feature(ds, unit_key="unit_source_value")
+
+    unit_report = _create_feature_unit_concept_id_report(backend_handle, ds)
+
+    var = ds["data_table_concept_id"].to_dataframe()
+    concepts = backend_handle.sql("SELECT * FROM concept").df()
+
+    if enrich_var_with_feature_info:
+        var = pd.merge(var, concepts, how="left", left_index=True, right_on="concept_id")
+
+    if enrich_var_with_unit_info:
+        if unit_report["multiple_units"].sum() > 0:
+            raise ValueError("Multiple units per feature found. Enrichment with unit information not possible.")
+        else:
+            var = pd.merge(
+                var,
+                unit_report,
+                how="left",
+                left_index=True,
+                right_on="unit_concept_id",
+                suffixes=("", "_unit"),
+            )
+            var = pd.merge(
+                var,
+                concepts,
+                how="left",
+                left_on="unit_concept_id",
+                right_on="concept_id",
+                suffixes=("", "_unit"),
+            )
-        concept_ids_present_list.append(concept_ids_present)
-        time_interval_tables.append(time_interval_table)
-    if len(time_interval_tables) > 1:
-        time_interval_table = np.concatenate([time_interval_table, time_interval_table], axis=1)
-        concept_ids_present = pd.concat(concept_ids_present_list)
-    else:
-        time_interval_table = time_interval_tables[0]
-        concept_ids_present = concept_ids_present_list[0]
-    # TODO: copy other fields too. or other way? is is somewhat scverse-y by taking and returing anndata object...
-    edata = EHRData(r=time_interval_table, obs=edata.obs, var=concept_ids_present)
+
+    t = ds["interval_step"].to_dataframe()
+
+    edata = EHRData(r=ds[data_field_to_keep[0]].values, obs=edata.obs, var=var, uns=edata.uns, t=t)
+    edata.uns[f"unit_report_{data_tables[0]}"] = unit_report
 
     return edata
 
 
@@ -155,10 +398,6 @@ def load(
     # folder_path: str,
     # delimiter: str = ",",
     # make_filename_lowercase: bool = True,
-    # use_dask: bool = False,
-    # level: Literal["stay_level", "patient_level"] = "stay_level",
-    # load_tables: str | list[str] | tuple[str] | Literal["auto"] | None = None,
-    # remove_empty_column: bool = True,
 ) -> None:
     """Initialize a connection to the OMOP CDM Database."""
     if isinstance(backend_handle, str) or isinstance(backend_handle, Path):
@@ -169,68 +408,104 @@ def load(
         raise NotImplementedError(f"Backend {backend_handle} not supported. 
Choose a valid backend.") -def extract_person(duckdb_instance): - """Extract person table of an OMOP CDM Database.""" - return duckdb_instance.sql("SELECT * FROM person").df() - - -def extract_observation_period(duckdb_instance): - """Extract person table of an OMOP CDM Database.""" - return duckdb_instance.sql("SELECT * FROM observation_period").df() +def get_table(duckdb_instance, table_name: str) -> pd.DataFrame: + """Extract a table of an OMOP CDM Database.""" + return _lowercase_column_names(duckdb_instance.sql(f"SELECT * FROM {table_name}").df()) -def extract_person_observation_period(duckdb_instance): - """Extract observation table of an OMOP CDM Database.""" - return duckdb_instance.sql( - "SELECT * \ - FROM person \ - LEFT JOIN observation_period USING(person_id) \ +def _get_table_join( + duckdb_instance, table1: str, table2: str, left_key: str = "person_id", right_key: str = "person_id" +) -> pd.DataFrame: + """Extract a table of an OMOP CDM Database.""" + return _lowercase_column_names( + duckdb_instance.sql( + f"SELECT * \ + FROM {table1} as t1 \ + JOIN {table2} as t2 ON t1.{left_key} = t2.{right_key} \ " - ).df() - - -def extract_measurement(duckdb_instance=None): - """Extract measurement table of an OMOP CDM Database.""" - measurement_table = duckdb_instance.sql("SELECT * FROM measurement").df() - - # get an array n_person x n_features x 2, one for value, one for time - person_id = ( - duckdb_instance.sql("SELECT * FROM person").df()["person_id"].unique() - ) # TODO: in anndata? w.r.t database? for now this - features = measurement_table["measurement_concept_id"].unique() - person_collection = [] - - for person in person_id: - person_as_list = [] - person_measurements = measurement_table[ - measurement_table["person_id"] == person - ] # or ofc sql in rdbms - lazy, on disk, first step towards huge memory reduction of this prototype if only load this selection - # person_measurements = person_measurements.sort_values(by="measurement_date") - # person_measurements = person_measurements[["measurement_date", "value_as_number"]] - # print(person_measurements) - for feature in features: - person_feature = [] - - # person_measurements_value = [] - # person_measurements_timestamp = [] - - person_feature_measurements = person_measurements["measurement_concept_id"] == feature - - person_feature_measurements_value = person_measurements[person_feature_measurements][ - "value_as_number" - ] # again, rdbms/spark backend big time scalable here - person_feature_measurements_timestamp = person_measurements[person_feature_measurements][ - "measurement_datetime" - ] - - person_feature.append(person_feature_measurements_value) - person_feature.append(person_feature_measurements_timestamp) - - person_as_list.append(person_feature) - - person_collection.append(person_as_list) - - return ak.Array(person_collection) + ).df() + ) + + +def extract_measurement(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + return get_table( + duckdb_instance, + table_name="measurement", + concept_id_col="measurement_concept_id", + value_col="value_as_number", + timestamp_col="measurement_datetime", + ) + + +def extract_observation(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + return get_table( + duckdb_instance, + table_name="observation", + concept_id_col="observation_concept_id", + value_col="value_as_number", + timestamp_col="observation_datetime", + ) + + +def extract_procedure_occurrence(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + return 
get_table( + duckdb_instance, + table_name="procedure_occurrence", + concept_id_col="procedure_concept_id", + value_col="procedure_type_concept_id", # Assuming `procedure_type_concept_id` is a suitable value field + timestamp_col="procedure_datetime", + ) + + +def extract_specimen(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + return get_table( + duckdb_instance, + table_name="specimen", + concept_id_col="specimen_concept_id", + value_col="unit_concept_id", # Assuming `unit_concept_id` is a suitable value field + timestamp_col="specimen_datetime", + ) + + +def extract_device_exposure(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + # return get_table( + # duckdb_instance, + # table_name="device_exposure", + # concept_id_col="device_concept_id", + # value_col="device_type_concept_id", # Assuming this as value + # timestamp_col="device_exposure_start_date" + # ) + # NEEDS IMPLEMENTATION + return None + + +def extract_drug_exposure(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + # return get_table( + # duckdb_instance, + # table_name="drug_exposure", + # concept_id_col="drug_concept_id", + # value_col="dose_unit_concept_id", # Assuming `dose_unit_concept_id` as value + # timestamp_col="drug_exposure_start_datetime" + # ) + # NEEDS IMPLEMENTATION + return None + + +def extract_note(duckdb_instance): + """Extract a table of an OMOP CDM Database.""" + return get_table( + duckdb_instance, + table_name="note", + concept_id_col="note_type_concept_id", + value_col="note_class_concept_id", # Assuming `note_class_concept_id` as value + timestamp_col="note_datetime", + ) def _get_interval_table_from_awkward_array( @@ -320,11 +595,20 @@ def get_time_interval_table( concept_id_list = concept_ids if num_intervals == "max_observation_duration": + observation_period_df = con.execute("SELECT * from observation_period").df() + observation_period_df = _lowercase_column_names(observation_period_df) + + # Calculate the duration of observation periods num_intervals = np.max( - con.execute("SELECT * from observation_period").df()["observation_period_end_date"] - - con.execute("SELECT * from observation_period").df()["observation_period_start_date"] + observation_period_df["observation_period_end_date"] + - observation_period_df["observation_period_start_date"] ) / pd.to_timedelta(interval_length_number, interval_length_unit) num_intervals = int(np.ceil(num_intervals)) + # num_intervals = np.max( + # con.execute("SELECT * from observation_period").df()["observation_period_end_date"] + # - con.execute("SELECT * from observation_period").df()["observation_period_start_date"] + # ) / pd.to_timedelta(interval_length_number, interval_length_unit) + # num_intervals = int(np.ceil(num_intervals)) tables = [] for person, person_ts in zip(obs.iterrows(), ts, strict=False): @@ -354,36 +638,17 @@ def get_time_interval_table( return np.array(tables).transpose(0, 2, 1) # TODO: store in self, np -def extract_observation(): - """Extract observation table of an OMOP CDM Database.""" - pass - - -def extract_procedure_occurrence(): - """Extract procedure_occurrence table of an OMOP CDM Database.""" - pass - - -def extract_specimen(): - """Extract specimen table of an OMOP CDM Database.""" - pass - - -def extract_device_exposure(): - """Extract device_exposure table of an OMOP CDM Database.""" - pass - - -def extract_drug_exposure(): - """Extract drug_exposure table of an OMOP CDM Database.""" - pass +def _lowercase_column_names(df: pd.DataFrame) -> pd.DataFrame: + 
"""Normalize all column names to lowercase.""" + df.columns = map(str.lower, df.columns) # Convert all column names to lowercase + return df def extract_condition_occurrence(): - """Extract condition_occurrence table of an OMOP CDM Database.""" + """Extract a table of an OMOP CDM Database.""" pass -def extract_note(): - """Extract note table of an OMOP CDM Database.""" +def extract_observation_period(): + """Extract a table of an OMOP CDM Database.""" pass diff --git a/src/ehrdata/utils/_omop_utils.py b/src/ehrdata/utils/_omop_utils.py index 7385538..2b52d02 100644 --- a/src/ehrdata/utils/_omop_utils.py +++ b/src/ehrdata/utils/_omop_utils.py @@ -6,6 +6,7 @@ import os import warnings from pathlib import Path +from typing import Literal # import dask.dataframe as dd import numpy as np @@ -13,8 +14,13 @@ from rich import print as rprint -def get_table_catalog_dict(): - """Get the table catalog dictionary of the OMOP CDM v5.4. +def get_table_catalog_dict(version: Literal["5.4"] = "5.4"): + """Get the table catalog dictionary of the OMOP CDM. + + Parameters + ---------- + version + The version of the OMOP CDM. Currently, only 5.4 is supported. Returns ------- @@ -61,9 +67,32 @@ def get_table_catalog_dict(): "source_to_concept_map", "drug_strength", ] + return table_catalog_dict +def get_omop_table_names(version: Literal["5.4"] = "5.4"): + """Get the table names of the OMOP CDM. + + Args + ---- + version: str, the version of the OMOP CDM. Currently, only 5.4 is supported. + + Returns + ------- + List of table names + """ + if version != "5.4": + raise ValueError("Only support OMOP CDM v5.4!") + + table_catalog_dict = get_table_catalog_dict(version=version) + tables = [] + for _, value_list in table_catalog_dict.items(): + for value in value_list: + tables.append(value) + return tables + + def get_dtype_mapping(): """Get the data type mapping of the OMOP CDM v5.4. 
diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..8f5fbc0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,12 @@ +import duckdb +import pytest + +from ehrdata.io.omop import register_omop_to_db_connection + + +@pytest.fixture # (scope="session") +def omop_connection_vanilla(): + con = duckdb.connect() + register_omop_to_db_connection(path="tests/data/toy_omop/vanilla", backend_handle=con, source="csv") + yield con + con.close() diff --git a/tests/data/toy_omop/vanilla/cohort.csv b/tests/data/toy_omop/vanilla/cohort.csv new file mode 100644 index 0000000..e9e2ef6 --- /dev/null +++ b/tests/data/toy_omop/vanilla/cohort.csv @@ -0,0 +1,4 @@ +cohort_definition_id,subject_id,cohort_start_date,cohort_end_date +1,1,2100-01-01,2100-01-31 +1,2,2100-01-01,2100-01-31 +1,3,2100-01-01,2100-01-31 diff --git a/tests/data/toy_omop/vanilla/concept.csv b/tests/data/toy_omop/vanilla/concept.csv new file mode 100644 index 0000000..6ca864c --- /dev/null +++ b/tests/data/toy_omop/vanilla/concept.csv @@ -0,0 +1 @@ +concept_id,concept_name,domain_id,vocabulary_id,concept_class_id,standard_concept,concept_code,valid_start_DATE,valid_end_DATE,invalid_reason diff --git a/tests/data/toy_omop/vanilla/death.csv b/tests/data/toy_omop/vanilla/death.csv new file mode 100644 index 0000000..3475d47 --- /dev/null +++ b/tests/data/toy_omop/vanilla/death.csv @@ -0,0 +1,3 @@ +person_id,death_date,death_datetime,death_type_concept_id,cause_concept_id,cause_source_value,cause_source_concept_id +1,2100-03-31,2100-03-31 00:00:00,32817,0,0, +2,2100-03-31,2100-03-31 00:00:00,32817,0,0, diff --git a/tests/data/toy_omop/vanilla/measurement.csv b/tests/data/toy_omop/vanilla/measurement.csv new file mode 100644 index 0000000..222c9a2 --- /dev/null +++ b/tests/data/toy_omop/vanilla/measurement.csv @@ -0,0 +1,10 @@ +measurement_id,person_id,measurement_concept_id,measurement_date,measurement_datetime,measurement_time,measurement_type_concept_id,operator_concept_id,value_as_number,value_as_concept_id,unit_concept_id,range_low,range_high,provider_id,visit_occurrence_id,visit_detail_id,measurement_source_value,measurement_source_concept_id,unit_source_value,value_source_value +1,1,3031147,2100-01-01,2100-01-01 12:00:00,12:00:00,32856,,18,,9557,21,30,,1,,50804,2000001003,mEq/L,18 +2,1,3031147,2100-01-01,2100-01-01 13:00:00,13:00:00,32856,,19,,9557,21,30,,1,,50804,2000001003,mEq/L,19 +3,1,3022318,2100-01-01,2100-01-01 14:00:00,14:00:00,32817,,,45877096,,,,,1,,220048,2000030004,,SR (Sinus Rhythm) +4,2,3031147,2100-01-01,2100-01-01 12:00:00,12:00:00,32856,,20,,9557,21,30,,2,,50804,2000001003,mEq/L,20 +5,2,3031147,2100-01-01,2100-01-01 13:00:00,13:00:00,32856,,21,,9557,21,30,,2,,50804,2000001003,mEq/L,21 +6,2,3022318,2100-01-01,2100-01-01 14:00:00,14:00:00,32817,,,45877096,,,,,2,,220048,2000030004,,SR (Sinus Rhythm) +7,3,3031147,2100-01-01,2100-01-01 12:00:00,12:00:00,32856,,22,,9557,21,30,,3,,50804,2000001003,mEq/L,22 +8,3,3031147,2100-01-01,2100-01-01 13:00:00,13:00:00,32856,,23,,9557,21,30,,3,,50804,2000001003,mEq/L,23 +9,3,3022318,2100-01-01,2100-01-01 14:00:00,14:00:00,32817,,,45883018,,,,,3,,220048,2000030004,,AF (Atrial Fibrillation) diff --git a/tests/data/toy_omop/vanilla/observation.csv b/tests/data/toy_omop/vanilla/observation.csv new file mode 100644 index 0000000..0cd51c2 --- /dev/null +++ b/tests/data/toy_omop/vanilla/observation.csv @@ -0,0 +1,10 @@ 
+observation_id,person_id,observation_concept_id,observation_date,observation_datetime,observation_type_concept_id,value_as_number,value_as_string,value_as_concept_id,qualifier_concept_id,unit_concept_id,provider_id,visit_occurrence_id,visit_detail_id,observation_source_value,observation_source_concept_id,unit_source_value,qualifier_source_value +1,1,3001062,2100-01-01,2100-01-01 12:00:00,32817,,Anemia,0,,,,,,225059,2000030108,, +2,1,3001062,2100-01-01,2100-01-01 13:00:00,32817,,Anemia,0,,,,,,225059,2000030108,, +3,1,3034263,2100-01-01,2100-01-01 14:00:00,32817,3,,,,,,,,224409,2000030058,, +4,2,3001062,2100-01-01,2100-01-01 12:00:00,32817,,Anemia,0,,,,,,225059,2000030108,, +5,2,3001062,2100-01-01,2100-01-01 13:00:00,32817,,Anemia,0,,,,,,225059,2000030108,, +6,2,3034263,2100-01-01,2100-01-01 14:00:00,32817,4,,,,,,,,224409,2000030058,, +7,3,3001062,2100-01-01,2100-01-01 12:00:00,32817,,Anemia,0,,,,,,225059,2000030108,, +8,3,3001062,2100-01-01,2100-01-01 13:00:00,32817,,Anemia,0,,,,,,225059,2000030108,, +9,3,3034263,2100-01-01,2100-01-01 14:00:00,32817,5,,,,,,,,224409,2000030058,, diff --git a/tests/data/toy_omop/vanilla/observation_period.csv b/tests/data/toy_omop/vanilla/observation_period.csv new file mode 100644 index 0000000..11df294 --- /dev/null +++ b/tests/data/toy_omop/vanilla/observation_period.csv @@ -0,0 +1,4 @@ +observation_period_id,person_id,observation_period_start_date,observation_period_end_date,period_type_concept_id +1,1,2100-01-01,2100-01-31,32828 +2,2,2100-01-01,2100-01-31,32828 +3,3,2100-01-01,2100-01-31,32828 diff --git a/tests/data/toy_omop/vanilla/person.csv b/tests/data/toy_omop/vanilla/person.csv new file mode 100644 index 0000000..18b89ef --- /dev/null +++ b/tests/data/toy_omop/vanilla/person.csv @@ -0,0 +1,5 @@ +person_id,gender_concept_id,year_of_birth,month_of_birth,day_of_birth,birth_datetime,race_concept_id,ethnicity_concept_id,location_id,provider_id,care_site_id,person_source_value,gender_source_value,gender_source_concept_id,race_source_value,race_source_concept_id,ethnicity_source_value,ethnicity_source_concept_id +1,8507,2095,,,,0,38003563,,,,1234,M,0,,,, +2,8507,2096,,,,0,38003563,,,,1235,M,0,,,, +3,8532,2097,,,,0,0,,,,1236,F,0,,,, +4,8532,2098,,,,0,0,,,,1237,F,0,,,, diff --git a/tests/data/toy_omop/vanilla/visit_occurrence.csv b/tests/data/toy_omop/vanilla/visit_occurrence.csv new file mode 100644 index 0000000..d7b1087 --- /dev/null +++ b/tests/data/toy_omop/vanilla/visit_occurrence.csv @@ -0,0 +1,4 @@ +visit_occurrence_id,person_id,visit_concept_id,visit_start_date,visit_start_datetime,visit_end_date,visit_end_datetime,visit_type_concept_id,provider_id,care_site_id,visit_source_value,visit_source_concept_id,admitting_source_concept_id,admitting_source_value,discharge_to_concept_id,discharge_to_source_value,preceding_visit_occurrence_id +1,1,8870,2100-01-01,2100-01-01 00:00:00,2100-01-31,2100-01-31 00:00:00,,,,10014354|2147-07-08,2000001801,,,,, +2,2,8870,2100-01-01,2100-01-01 00:00:00,2100-01-31,2100-01-31 00:00:00,,,,10014354|2147-07-08,2000001801,,,,, +3,3,8870,2100-01-01,2100-01-01 00:00:00,2100-01-31,2100-01-31 00:00:00,,,,10014354|2147-07-08,2000001801,,,,, diff --git a/tests/test_dt/test_dt.py b/tests/test_dt/test_dt.py new file mode 100644 index 0000000..72fa7a3 --- /dev/null +++ b/tests/test_dt/test_dt.py @@ -0,0 +1,25 @@ +import duckdb + +import ehrdata as ed + + +def test_mimic_iv_omop(): + con = duckdb.connect() + ed.dt.mimic_iv_omop(backend_handle=con) + assert len(con.execute("SHOW TABLES").df()) == 30 + con.close() + + +# TODO +# def 
test_gibleed_omop(): +# con = duckdb.connect() +# ed.dt.gibleed_omop(backend_handle=con) +# assert len(con.execute("SHOW TABLES").df()) == 36 +# con.close() + + +# def test_synthea27nj_omop(): +# con = duckdb.connect() +# ed.dt.synthea27nj_omop(backend_handle=con) +# assert len(con.execute("SHOW TABLES").df()) == 37 +# con.close() diff --git a/tests/test_io/test_omop.py b/tests/test_io/test_omop.py new file mode 100644 index 0000000..68ed0fc --- /dev/null +++ b/tests/test_io/test_omop.py @@ -0,0 +1,276 @@ +import re + +import pytest + +import ehrdata as ed + +# constants for toy_omop/vanilla +VANILLA_PERSONS_WITH_OBSERVATION_TABLE_ENTRY = { + "person_cohort": 3, + "person_observation_period": 3, + "person_visit_occurrence": 3, +} +VANILLA_NUM_CONCEPTS = { + "measurement": 2, + "observation": 2, +} + +# constants for setup_variables +# only data_table_concept_id +VAR_DIM_BASE = 1 +# number of columns in concept table +NUMBER_COLUMNS_CONCEPT_TABLE = 10 +VAR_DIM_FEATURE_INFO = NUMBER_COLUMNS_CONCEPT_TABLE +# number of columns in concept table + number of columns +NUMBER_COLUMNS_FEATURE_REPORT = 4 +VAR_DIM_UNIT_INFO = NUMBER_COLUMNS_CONCEPT_TABLE + NUMBER_COLUMNS_FEATURE_REPORT + + +@pytest.mark.parametrize( + "observation_table, death_table, expected_length, expected_obs_num_columns", + [ + ("person", False, 4, 18), + ("person", True, 4, 24), + ("person_cohort", False, 3, 22), + ("person_cohort", True, 3, 28), + ("person_observation_period", False, 3, 23), + ("person_observation_period", True, 3, 29), + ("person_visit_occurrence", False, 3, 35), + ("person_visit_occurrence", True, 3, 41), + ], +) +def test_setup_obs(omop_connection_vanilla, observation_table, death_table, expected_length, expected_obs_num_columns): + con = omop_connection_vanilla + edata = ed.io.omop.setup_obs(backend_handle=con, observation_table=observation_table, death_table=death_table) + assert isinstance(edata, ed.EHRData) + + # 4 persons, only 3 are in cohort, or have observation period, or visit occurrence + assert len(edata) == expected_length + assert edata.obs.shape[1] == expected_obs_num_columns + + +@pytest.mark.parametrize( + "backend_handle, observation_table, death_table, expected_error", + [ + ("wrong_type", "person", False, "Expected backend_handle to be of type DuckDBPyConnection."), + (None, 123, False, "Expected observation_table to be a string."), + (None, "person", "wrong_type", "Expected death_table to be a boolean."), + ], +) +def test_setup_obs_illegal_argument_types( + omop_connection_vanilla, + backend_handle, + observation_table, + death_table, + expected_error, +): + with pytest.raises(TypeError, match=expected_error): + ed.io.omop.setup_obs( + backend_handle=backend_handle or omop_connection_vanilla, + observation_table=observation_table, + death_table=death_table, + ) + + +def test_setup_obs_invalid_observation_table_value(omop_connection_vanilla): + con = omop_connection_vanilla + with pytest.raises( + ValueError, + match=re.escape( + "observation_table must be one of ['person', 'person_cohort', 'person_observation_period', 'person_visit_occurrence']." 
+ ), + ): + ed.io.omop.setup_obs(backend_handle=con, observation_table="perso") + + +@pytest.mark.parametrize( + "observation_table", + ["person_cohort", "person_observation_period", "person_visit_occurrence"], +) +@pytest.mark.parametrize( + "data_tables", + [["measurement"], ["observation"]], +) +@pytest.mark.parametrize( + "data_field_to_keep", + [["value_as_number"], ["value_as_concept_id"]], +) +@pytest.mark.parametrize( + "enrich_var_with_feature_info", + [True, False], +) +@pytest.mark.parametrize( + "enrich_var_with_unit_info", + [True, False], +) +def test_setup_variables( + omop_connection_vanilla, + observation_table, + data_tables, + data_field_to_keep, + enrich_var_with_feature_info, + enrich_var_with_unit_info, +): + num_intervals = 4 + con = omop_connection_vanilla + edata = ed.io.omop.setup_obs(backend_handle=con, observation_table=observation_table) + edata = ed.io.omop.setup_variables( + edata, + backend_handle=con, + data_tables=data_tables, + data_field_to_keep=data_field_to_keep, + interval_length_number=1, + interval_length_unit="day", + num_intervals=num_intervals, + enrich_var_with_feature_info=enrich_var_with_feature_info, + enrich_var_with_unit_info=enrich_var_with_unit_info, + ) + + assert isinstance(edata, ed.EHRData) + assert edata.n_obs == VANILLA_PERSONS_WITH_OBSERVATION_TABLE_ENTRY[observation_table] + assert edata.n_vars == VANILLA_NUM_CONCEPTS[data_tables[0]] + assert edata.r.shape[2] == num_intervals + assert edata.var.shape[1] == VAR_DIM_BASE + (VAR_DIM_FEATURE_INFO if enrich_var_with_feature_info else 0) + ( + VAR_DIM_UNIT_INFO if enrich_var_with_unit_info else 0 + ) + + +@pytest.mark.parametrize( + "edata, backend_handle, data_tables, data_field_to_keep, interval_length_number, interval_length_unit, num_intervals, enrich_var_with_feature_info, enrich_var_with_unit_info, expected_error", + [ + ( + "wrong_type", + None, + ["measurement"], + ["value_as_number"], + 1, + "day", + 4, + False, + False, + "Expected edata to be of type EHRData.", + ), + ( + None, + "wrong_type", + ["measurement"], + ["value_as_number"], + 1, + "day", + 4, + False, + False, + "Expected backend_handle to be of type DuckDBPyConnection.", + ), + ( + None, + None, + 123, + ["value_as_number"], + 1, + "day", + 4, + False, + False, + "Expected data_tables to be a string or Sequence.", + ), + ( + None, + None, + ["measurement"], + 123, + 1, + "day", + 4, + False, + False, + "Expected data_field_to_keep to be a string or Sequence.", + ), + ( + None, + None, + ["measurement"], + ["value_as_number"], + "wrong_type", + "day", + 4, + False, + False, + "Expected interval_length_number to be an integer.", + ), + ( + None, + None, + ["measurement"], + ["value_as_number"], + 1, + 123, + 4, + False, + False, + "Expected interval_length_unit to be a string.", + ), + ( + None, + None, + ["measurement"], + ["value_as_number"], + 1, + "day", + "wrong_type", + False, + False, + "Expected num_intervals to be an integer.", + ), + ( + None, + None, + ["measurement"], + ["value_as_number"], + 1, + "day", + 123, + "wrong_type", + False, + "Expected enrich_var_with_feature_info to be a boolean.", + ), + ( + None, + None, + ["measurement"], + ["value_as_number"], + 1, + "day", + 123, + False, + "wrong_type", + "Expected enrich_var_with_unit_info to be a boolean.", + ), + ], +) +def test_setup_variables_illegal_argument_types( + omop_connection_vanilla, + edata, + backend_handle, + data_tables, + data_field_to_keep, + interval_length_number, + interval_length_unit, + num_intervals, + 
enrich_var_with_feature_info, + enrich_var_with_unit_info, + expected_error, +): + con = omop_connection_vanilla + with pytest.raises(TypeError, match=expected_error): + ed.io.omop.setup_variables( + edata or ed.io.omop.setup_obs(backend_handle=omop_connection_vanilla, observation_table="person_cohort"), + backend_handle=backend_handle or con, + data_tables=data_tables, + data_field_to_keep=data_field_to_keep, + interval_length_number=interval_length_number, + interval_length_unit=interval_length_unit, + num_intervals=num_intervals, + enrich_var_with_feature_info=enrich_var_with_feature_info, + enrich_var_with_unit_info=enrich_var_with_unit_info, + )
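
End-to-end, the new API composes as in the tests above. A sketch of the intended usage, assuming the vanilla toy data shipped with this PR (3 persons in the cohort, 2 measurement concepts):

```python
import duckdb

import ehrdata as ed
from ehrdata.io.omop import register_omop_to_db_connection

con = duckdb.connect()
register_omop_to_db_connection(path="tests/data/toy_omop/vanilla", backend_handle=con, source="csv")

# setup_obs builds .obs from person joined with the chosen time-defining table
# and records that table in .uns["omop_io_observation_table"].
edata = ed.io.omop.setup_obs(backend_handle=con, observation_table="person_cohort")

# setup_variables fills .r with a person x concept x interval array and writes
# a unit report to .uns["unit_report_measurement"].
edata = ed.io.omop.setup_variables(
    edata,
    backend_handle=con,
    data_tables=["measurement"],
    data_field_to_keep=["value_as_number"],
    interval_length_number=1,
    interval_length_unit="day",
    num_intervals=4,
    aggregation_strategy="last",
)
print(edata.r.shape)  # (3, 2, 4) on the toy data
con.close()
```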