Skip to content

Commit

Permalink
Refactor time interval table query function
Browse files: browse the repository at this point in the history
  • Loading branch information
eroell committed Nov 22, 2024
1 parent 623ebab commit 8b070c4
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 249 deletions.
94 changes: 17 additions & 77 deletions src/ehrdata/io/omop/_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def _generate_value_query(data_table: str, data_field_to_keep: Sequence, aggrega
return is_present_query + value_query


def time_interval_table_query_long_format(
def _time_interval_table(
backend_handle: duckdb.duckdb.DuckDBPyConnection,
time_defining_table: str,
data_table: str,
Expand All @@ -124,14 +124,13 @@ def time_interval_table_query_long_format(
num_intervals: int,
aggregation_strategy: str,
data_field_to_keep: Sequence[str] | str,
date_prefix: str = "",
) -> pd.DataFrame:
"""Returns a long format DataFrame from the data_table. The following columns should be considered the indices of this long format: person_id, data_table_concept_id, interval_step. The other columns, except for start_date and end_date, should be considered the values."""
keep_date: str = "",
):
if isinstance(data_field_to_keep, str):
data_field_to_keep = [data_field_to_keep]

if date_prefix == "":
date_prefix = "timepoint"
if keep_date == "":
keep_date = "timepoint"

timedeltas_dataframe = _generate_timedeltas(interval_length_number, interval_length_unit, num_intervals)

Expand All @@ -146,8 +145,7 @@ def time_interval_table_query_long_format(
# 3. Create long_format_backbone, which is the left join of person_time_defining_table and person_data_table.
# 4. Create long_format_intervals, which is the cross product of long_format_backbone and timedeltas. This table contains most notably the person_id, the concept_id, the interval start and end dates.
# 5. Create the final table, which is the join with the data_table (typically measurement); each measurement is assigned to its person_id, its concept_id, and the interval it fits into.
df = backend_handle.execute(
f"""
prepare_alias_query = f"""
WITH person_time_defining_table AS ( \
SELECT person.person_id as person_id, {DATA_TABLE_DATE_KEYS["start"][time_defining_table]} as start_date, {DATA_TABLE_DATE_KEYS["end"][time_defining_table]} as end_date \
FROM person \
Expand Down Expand Up @@ -176,79 +174,18 @@ def time_interval_table_query_long_format(
SELECT *, 1 as is_present \
FROM {data_table} \
) \
"""

if keep_date in ["timepoint", "start", "end"]:
select_query = f"""
SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query("data_table_with_presence_indicator", data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \
FROM long_format_intervals as lfi \
LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id AND lfi.data_table_concept_id = data_table_with_presence_indicator.{DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS[date_prefix][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \
LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id AND lfi.data_table_concept_id = data_table_with_presence_indicator.{DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS[keep_date][data_table]} BETWEEN lfi.interval_start AND lfi.interval_end \
GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end
"""
).df()

_drop_timedeltas(backend_handle)

return df


def time_interval_table_for_interval_tables_query_long_format(
backend_handle: duckdb.duckdb.DuckDBPyConnection,
time_defining_table: str,
data_table: str,
interval_length_number: int,
interval_length_unit: str,
num_intervals: int,
aggregation_strategy: str,
data_field_to_keep: Sequence[str] | str,
date_prefix: str = "",
) -> pd.DataFrame:
"""Returns a long format DataFrame from the data_table. The following columns should be considered the indices of this long format: person_id, data_table_concept_id, interval_step. The other columns, except for start_date and end_date, should be considered the values."""
if isinstance(data_field_to_keep, str):
data_field_to_keep = [data_field_to_keep]

if date_prefix != "":
date_prefix = date_prefix + "_"

timedeltas_dataframe = _generate_timedeltas(interval_length_number, interval_length_unit, num_intervals)

_write_timedeltas_to_db(
backend_handle,
timedeltas_dataframe,
)

# multi-step query
# 1. Create person_time_defining_table, which matches the one created for obs. Needs to contain the person_id, and the start date in particular.
# 2. Create person_data_table (data_table is typically measurement), which contains the cross product of person_id and the distinct concept_ids.
# 3. Create long_format_backbone, which is the left join of person_time_defining_table and person_data_table.
# 4. Create long_format_intervals, which is the cross product of long_format_backbone and timedeltas. This table contains most notably the person_id, the concept_id, the interval start and end dates.
# 5. Create the final table, which is the join with the data_table (typically measurement); each measurement is assigned to its person_id, its concept_id, and the interval it fits into.
df = backend_handle.execute(
f"""
WITH person_time_defining_table AS ( \
SELECT person.person_id as person_id, {DATA_TABLE_DATE_KEYS["start"][time_defining_table]} as start_date, {DATA_TABLE_DATE_KEYS["end"][time_defining_table]} as end_date \
FROM person \
JOIN {time_defining_table} ON person.person_id = {time_defining_table}.{TIME_DEFINING_TABLE_SUBJECT_KEY[time_defining_table]} \
), \
person_data_table AS( \
WITH distinct_data_table_concept_ids AS ( \
SELECT DISTINCT {DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id
FROM {data_table} \
)
SELECT person.person_id, {DATA_TABLE_CONCEPT_ID_TRUNK[data_table]}_concept_id as data_table_concept_id \
FROM person \
CROSS JOIN distinct_data_table_concept_ids \
), \
long_format_backbone as ( \
SELECT person_time_defining_table.person_id, data_table_concept_id, start_date, end_date \
FROM person_time_defining_table \
LEFT JOIN person_data_table USING(person_id)\
), \
long_format_intervals as ( \
SELECT person_id, data_table_concept_id, interval_step, start_date, start_date + interval_start_offset as interval_start, start_date + interval_end_offset as interval_end \
FROM long_format_backbone \
CROSS JOIN timedeltas \
), \
data_table_with_presence_indicator as( \
SELECT *, 1 as is_present \
FROM {data_table} \
) \
elif keep_date == "interval":
select_query = f"""
SELECT lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end, {_generate_value_query("data_table_with_presence_indicator", data_field_to_keep, AGGREGATION_STRATEGY_KEY[aggregation_strategy])} \
FROM long_format_intervals as lfi \
LEFT JOIN data_table_with_presence_indicator ON lfi.person_id = data_table_with_presence_indicator.person_id \
Expand All @@ -258,7 +195,10 @@ def time_interval_table_for_interval_tables_query_long_format(
OR (data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["start"][data_table]} < lfi.interval_start AND data_table_with_presence_indicator.{DATA_TABLE_DATE_KEYS["end"][data_table]} > lfi.interval_end)) \
GROUP BY lfi.person_id, lfi.data_table_concept_id, interval_step, interval_start, interval_end
"""
).df()

query = prepare_alias_query + select_query

df = backend_handle.execute(query).df()

_drop_timedeltas(backend_handle)

Expand Down
54 changes: 17 additions & 37 deletions src/ehrdata/io/omop/omop.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@
_check_valid_observation_table,
_check_valid_variable_data_tables,
)
from ehrdata.io.omop._queries import (
time_interval_table_for_interval_tables_query_long_format,
time_interval_table_query_long_format,
)
from ehrdata.io.omop._queries import _time_interval_table
from ehrdata.utils._omop_utils import get_table_catalog_dict

DOWNLOAD_VERIFICATION_TAG = "download_verification_tag"
Expand Down Expand Up @@ -335,7 +332,7 @@ def setup_variables(
return edata

ds = (
time_interval_table_query_long_format(
_time_interval_table(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
Expand Down Expand Up @@ -437,7 +434,7 @@ def setup_interval_variables(
Strategy to use when aggregating multiple data points within one interval.
enrich_var_with_feature_info
Whether to enrich the var table with feature information. If a concept_id is not found in the concept table, the feature information will be NaN.
keep_date
date_type
Which date of each record to use when assigning it to intervals: its start date ("start"), its end date ("end"), or its whole interval span ("interval").
Returns
Expand Down Expand Up @@ -469,38 +466,21 @@ def setup_interval_variables(
logging.info(f"No data in {data_tables}.")
return edata

if keep_date == "start" or keep_date == "end":
ds = (
time_interval_table_query_long_format(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
date_prefix=keep_date,
)
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
)
elif keep_date == "interval":
ds = (
time_interval_table_for_interval_tables_query_long_format(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
date_prefix=keep_date,
)
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
ds = (
_time_interval_table(
backend_handle=backend_handle,
time_defining_table=time_defining_table,
data_table=data_tables[0],
data_field_to_keep=data_field_to_keep,
interval_length_number=interval_length_number,
interval_length_unit=interval_length_unit,
num_intervals=num_intervals,
aggregation_strategy=aggregation_strategy,
keep_date=keep_date,
)
.set_index(["person_id", "data_table_concept_id", "interval_step"])
.to_xarray()
)

var = ds["data_table_concept_id"].to_dataframe()

Expand Down
Loading

0 comments on commit 8b070c4

Please sign in to comment.