Skip to content

Commit

Permalink
ehrdata refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
xinyuejohn committed Feb 14, 2024
1 parent d190a36 commit ea72e8a
Show file tree
Hide file tree
Showing 22 changed files with 1,422 additions and 1,518 deletions.
1,409 changes: 0 additions & 1,409 deletions ehrdata.py

This file was deleted.

7 changes: 7 additions & 0 deletions ehrdata/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from importlib.metadata import version

from . import dt, pl, pp, tl, io

__all__ = ["dt", "pl", "pp", "tl", "io"]

__version__ = "0.0.0"
1 change: 1 addition & 0 deletions ehrdata/dt/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ehrdata.dt._omop import init_omop
130 changes: 130 additions & 0 deletions ehrdata/dt/_omop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
import os


import pandas as pd

import ehrapy as ep
from pathlib import Path
from ehrdata.utils.omop_utils import *
from rich.console import Console
from rich.text import Text
import rich.repr
from rich import print as rprint
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, List




def init_omop(folder_path,
delimiter=None,
make_filename_lowercase=True,
use_dask=False,
level: Literal["stay_level", "patient_level"] = "stay_level",
tables: Union[str, List[str]] = None,
remove_empty_column=True):



filepath_dict = check_with_omop_cdm(folder_path=folder_path, delimiter=delimiter, make_filename_lowercase=make_filename_lowercase)
tables = list(filepath_dict.keys())
adata_dict = {}
adata_dict['filepath_dict'] = filepath_dict
adata_dict['tables'] = tables
adata_dict['delimiter'] = delimiter
adata_dict['use_dask'] = use_dask


table_catalog_dict = get_table_catalog_dict()

color_map = {
'Clinical data': 'blue',
'Health system data': 'green',
'Health economics data': 'red',
'Standardized derived elements': 'magenta',
'Metadata': 'white',
'Vocabulary': 'dark_orange'
}
# Object description
print_str = f'OMOP Database ([red]{os.path.basename(folder_path)}[/]) with {len(tables)} tables.\n'

# Tables information
for key, value in table_catalog_dict.items():
table_list = [table_name for table_name in tables if table_name in value]
if len(table_list) != 0:
print_str = print_str + f"[{color_map[key]}]{key} tables[/]: [black]{', '.join(table_list)}[/]\n"
#table_list_str = ', '.join(table_list)

#text = Text(f"{key} tables: ", style=color_map[key])
#text.append(table_list_str)
#yield None, f"{key} tables", "red"
rprint(print_str)

tables = ['person', 'death', 'visit_occurrence']
# TODO patient level and hospital level
if level == "stay_level":
index = {"visit_occurrence": "visit_occurrence_id", "person": "person_id", "death": "person_id"}
# TODO Only support clinical_tables_columns
table_dict = {}
for table in tables:
print(f"reading table [{table}]")
column_types = get_column_types(adata_dict, table_name=table)
df = read_table(adata_dict, table_name=table, dtype=column_types, index='person_id')
if remove_empty_column:
# TODO dask Support
#columns = [column for column in df.columns if not df[column].compute().isna().all()]
columns = [column for column in df.columns if not df[column].isna().all()]
df = df.loc[:, columns]
table_dict[table] = df

# concept_id_list = list(self.concept.concept_id)
# concept_name_list = list(self.concept.concept_id)
# concept_domain_id_list = list(set(self.concept.domain_id))

# self.loaded_tabel = ['visit_occurrence', 'person', 'death', 'measurement', 'observation', 'drug_exposure']
# TODO dask Support
joined_table = pd.merge(table_dict["visit_occurrence"], table_dict["person"], left_index=True, right_index=True, how="left")

joined_table = pd.merge(joined_table, table_dict["death"], left_index=True, right_index=True, how="left")

# TODO dask Support
#joined_table = joined_table.compute()

# TODO check this earlier
joined_table = joined_table.drop_duplicates(subset='visit_occurrence_id')
joined_table = joined_table.set_index("visit_occurrence_id")
# obs_only_list = list(self.joined_table.columns)
# obs_only_list.remove('visit_occurrence_id')
columns_obs_only = list(set(joined_table.columns) - set(["year_of_birth", "gender_source_value"]))
adata = ep.ad.df_to_anndata(
joined_table, index_column="visit_occurrence_id", columns_obs_only=columns_obs_only
)
# TODO this needs to be fixed because anndata set obs index as string by default
#adata.obs.index = adata.obs.index.astype(int)

"""
for column in self.measurement.columns:
if column != 'visit_occurrence_id':
obs_list = []
for visit_occurrence_id in adata.obs.index:
obs_list.append(list(self.measurement[self.measurement['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
adata.obsm[column]= ak.Array(obs_list)
for column in self.drug_exposure.columns:
if column != 'visit_occurrence_id':
obs_list = []
for visit_occurrence_id in adata.obs.index:
obs_list.append(list(self.drug_exposure[self.drug_exposure['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
adata.obsm[column]= ak.Array(obs_list)
for column in self.observation.columns:
if column != 'visit_occurrence_id':
obs_list = []
for visit_occurrence_id in adata.obs.index:
obs_list.append(list(self.observation[self.observation['visit_occurrence_id'] == int(visit_occurrence_id)][column]))
adata.obsm[column]= ak.Array(obs_list)
"""

adata.uns.update(adata_dict)

return adata

1 change: 1 addition & 0 deletions ehrdata/io/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ehrdata.io._omop import from_dataframe, to_dataframe
55 changes: 55 additions & 0 deletions ehrdata/io/_omop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import List, Union, Literal, Optional
import awkward as ak
import pandas as pd

def from_dataframe(
adata,
feature: str,
df
):
grouped = df.groupby("visit_occurrence_id")
unique_visit_occurrence_ids = set(adata.obs.index)

# Use set difference and intersection more efficiently
feature_ids = unique_visit_occurrence_ids.intersection(grouped.groups.keys())
empty_entry = {source_table_column: [] for source_table_column in set(df.columns) if source_table_column not in ['visit_occurrence_id'] }

# Creating the array more efficiently
ak_array = ak.Array([
grouped.get_group(visit_occurrence_id)[list(set(df.columns) - set(['visit_occurrence_id']))].to_dict(orient='list') if visit_occurrence_id in feature_ids else empty_entry
for visit_occurrence_id in unique_visit_occurrence_ids])
adata.obsm[feature] = ak_array

return adata

# TODO add function to check feature and add concept
# More IO functions

def to_dataframe(
adata,
features: Union[str, List[str]], # TODO also support list of features
# patient str or List, # TODO also support subset of patients/visit
):
# TODO
# can be viewed as patient level - only select some patient
# TODO change variable name here
if isinstance(features, str):
features = [features]
df_concat = pd.DataFrame([])
for feature in features:
df = ak.to_dataframe(adata.obsm[feature])

df.reset_index(drop=False, inplace=True)
df["entry"] = adata.obs.index[df["entry"]]
df = df.rename(columns={"entry": "visit_occurrence_id"})
del df["subentry"]
for col in df.columns:
if col.endswith('time'):
df[col] = pd.to_datetime(df[col])

df['feature_name'] = feature
df_concat = pd.concat([df_concat, df], axis= 0)


return df_concat

1 change: 1 addition & 0 deletions ehrdata/pl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ehrdata.pl._omop import feature_counts
59 changes: 59 additions & 0 deletions ehrdata/pl/_omop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import List, Union, Literal, Optional
from ehrdata.utils.omop_utils import *
from ehrdata.tl import get_concept_name
import seaborn as sns
import matplotlib.pyplot as plt

# TODO allow users to pass features
def feature_counts(
adata,
source: Literal[
"observation",
"measurement",
"procedure_occurrence",
"specimen",
"device_exposure",
"drug_exposure",
"condition_occurrence",
],
number=20,
key = None
):

if source == 'measurement':
columns = ["value_as_number", "time", "visit_occurrence_id", "measurement_concept_id"]
elif source == 'observation':
columns = ["value_as_number", "value_as_string", "measurement_datetime"]
elif source == 'condition_occurrence':
columns = None
else:
raise KeyError(f"Extracting data from {source} is not supported yet")

filepath_dict = adata.uns['filepath_dict']
tables = adata.uns['tables']

column_types = get_column_types(adata.uns, table_name=source)
df_source = read_table(adata.uns, table_name=source, dtype=column_types, usecols=[f"{source}_concept_id"])
feature_counts = df_source[f"{source}_concept_id"].value_counts()
if adata.uns['use_dask']:
feature_counts = feature_counts.compute()
feature_counts = feature_counts.to_frame().reset_index(drop=False)[0:number]


feature_counts[f"{source}_concept_id_1"], feature_counts[f"{source}_concept_id_2"] = map_concept_id(
adata.uns, concept_id=feature_counts[f"{source}_concept_id"], verbose=False
)
feature_counts["feature_name"] = get_concept_name(adata, concept_id=feature_counts[f"{source}_concept_id_1"])
if feature_counts[f"{source}_concept_id_1"].equals(feature_counts[f"{source}_concept_id_2"]):
feature_counts.drop(f"{source}_concept_id_2", axis=1, inplace=True)
feature_counts.rename(columns={f"{source}_concept_id_1": f"{source}_concept_id"})
feature_counts = feature_counts.reindex(columns=["feature_name", f"{source}_concept_id", "count"])
else:
feature_counts = feature_counts.reindex(
columns=["feature_name", f"{source}_concept_id_1", f"{source}_concept_id_2", "count"]
)

ax = sns.barplot(feature_counts, x="feature_name", y="count")
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
return feature_counts
1 change: 1 addition & 0 deletions ehrdata/pp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ehrdata.pp._omop import get_feature_statistics
110 changes: 110 additions & 0 deletions ehrdata/pp/_omop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from typing import List, Union, Literal, Optional
from ehrdata.utils.omop_utils import *
import ehrapy as ep
import warnings

def get_feature_statistics(
adata,
source: Literal[
"observation",
"measurement",
"procedure_occurrence",
"specimen",
"device_exposure",
"drug_exposure",
"condition_occurrence",
],
features: Union[str, int , List[Union[str, int]]] = None,
level="stay_level",
value_col: str = None,
aggregation_methods: Union[Literal["min", "max", "mean", "std", "count"], List[Literal["min", "max", "mean", "std", "count"]]]=None,
add_aggregation_to_X: bool = True,
verbose: bool = False,
use_dask: bool = None,
):
if source in ["measurement", "observation", "specimen"]:
key = f"{source}_concept_id"
elif source in ["device_exposure", "procedure_occurrence", "drug_exposure", "condition_occurrence"]:
key = f"{source.split('_')[0]}_concept_id"
else:
raise KeyError(f"Extracting data from {source} is not supported yet")

if source == 'measurement':
value_col = 'value_as_number'
warnings.warn(f"Extracting values from {value_col}. Value in measurement table could be saved in these columns: value_as_number, value_source_value.\nSpecify value_col to extract value from desired column.")
source_table_columns = ['visit_occurrence_id', 'measurement_datetime', key, value_col]
elif source == 'observation':
value_col = 'value_as_number'
warnings.warn(f"Extracting values from {value_col}. Value in observation table could be saved in these columns: value_as_number, value_as_string, value_source_value.\nSpecify value_col to extract value from desired column.")
source_table_columns = ['visit_occurrence_id', "observation_datetime", key, value_col]
elif source == 'condition_occurrence':
source_table_columns = None
else:
raise KeyError(f"Extracting data from {source} is not supported yet")
if isinstance(features, str):
features = [features]
rprint(f"Trying to extarct the following features: {features}")

if use_dask is None:
use_dask = True

column_types = get_column_types(adata.uns, table_name=source)
df_source = read_table(adata.uns, table_name=source, dtype=column_types, usecols=source_table_columns, use_dask=use_dask)

info_df = get_feature_info(adata.uns, features=features, verbose=verbose)
info_dict = info_df[['feature_id', 'feature_name']].set_index('feature_id').to_dict()['feature_name']

# Select featrues
df_source = df_source[df_source[key].isin(list(info_df.feature_id))]
#TODO Select time
#da_measurement = da_measurement[(da_measurement.time >= 0) & (da_measurement.time <= 48*60*60)]
#df_source[f'{source}_name'] = df_source[key].map(info_dict)
if aggregation_methods is None:
aggregation_methods = ["min", "max", "mean", "std", "count"]
if level == 'stay_level':
result = df_source.groupby(['visit_occurrence_id', key]).agg({
value_col: aggregation_methods})

if use_dask:
result = result.compute()
result = result.reset_index(drop=False)
result.columns = ["_".join(a) for a in result.columns.to_flat_index()]
result.columns = result.columns.str.removesuffix('_')
result.columns = result.columns.str.removeprefix(f'{value_col}_')
result[f'{source}_name'] = result[key].map(info_dict)

df_statistics = result.pivot(index='visit_occurrence_id',
columns=f'{source}_name',
values=aggregation_methods)
df_statistics.columns = df_statistics.columns.swaplevel()
df_statistics.columns = ["_".join(a) for a in df_statistics.columns.to_flat_index()]


# TODO
sort_columns = True
if sort_columns:
new_column_order = []
for feature in features:
for suffix in (f'_{aggregation_method}' for aggregation_method in aggregation_methods):
col_name = f'{feature}{suffix}'
if col_name in df_statistics.columns:
new_column_order.append(col_name)

df_statistics.columns = new_column_order

df_statistics.index = df_statistics.index.astype(str)

adata.obs = pd.merge(adata.obs, df_statistics, how='left', left_index=True, right_index=True)

if add_aggregation_to_X:
uns = adata.uns
obsm = adata.obsm
varm = adata.varm
layers = adata.layers
adata = ep.ad.move_to_x(adata, list(df_statistics.columns))
adata.uns = uns
adata.obsm = obsm
adata.varm = varm
# It will change
# adata.layers = layers
return adata
1 change: 1 addition & 0 deletions ehrdata/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ehrdata.tl._omop import get_concept_name
Loading

0 comments on commit ea72e8a

Please sign in to comment.