Commit b951ac2
improve OMOP loading
xinyuejohn committed Jan 10, 2024
1 parent 54b6658 commit b951ac2
Showing 1 changed file with 81 additions and 50 deletions.
131 changes: 81 additions & 50 deletions ehrdata.py
@@ -221,10 +221,8 @@ def _get_column_types(self, path=None, columns=None):
raise TypeError("Only support CSV and Parquet file!")
columns_lowercase = [column.lower() for column in columns]
for i, column in enumerate(columns_lowercase):
if hasattr(self, "additional_column"):
if column in self.additional_column.keys():
column_types[columns[i]] = self.additional_column[column]

if hasattr(self, "additional_column") and column in self.additional_column.keys():
column_types[columns[i]] = self.additional_column[column]
elif column.endswith(
(
"source_value",
@@ -247,7 +245,7 @@ def _get_column_types(self, path=None, columns=None):
"lot_number",
)
):
column_types[columns[i]] = str
column_types[columns[i]] = object
# TODO quantity in different tables have different types
elif column.endswith(("as_number", "low", "high", "quantity")):
column_types[columns[i]] = float
@@ -257,34 +255,51 @@ def _get_column_types(self, path=None, columns=None):
parse_dates.append(columns[i])
elif column.endswith(("id", "birth", "id_1", "id_2", "refills", "days_supply")):
column_types[columns[i]] = "Int64"

else:
raise KeyError(f"{columns[i]} is not defined in OMOP CDM")
if len(parse_dates) == 0:
parse_dates = None
return column_types, parse_dates
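For orientation (not part of this commit), the dtype mapping that _get_column_types builds for a typical OMOP measurement table would look roughly like the following; the concrete column names are assumptions based on the standard CDM:

# Hypothetical result of _get_column_types for measurement.csv
column_types = {
    "measurement_id": "Int64",
    "person_id": "Int64",
    "measurement_concept_id": "Int64",
    "value_as_number": float,
    "value_source_value": object,
    "unit_source_value": object,
}
parse_dates = ["measurement_date", "measurement_datetime"]
# Both are handed to the pandas/dask readers below, e.g.
# pd.read_csv(path, dtype=column_types, parse_dates=parse_dates)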

def _read_table(self, path, dtype=None, parse_dates=None, index=None, usecols=None, **kwargs):
def _read_table(self, path, dtype=None, parse_dates=None, index=None, usecols=None, use_dask=False, **kwargs):

    if use_dask:
        if not os.path.isfile(path):
            folder_walk = os.walk(path)
            filetype = next(folder_walk)[2][0].split(".")[-1]
        else:
            filetype = path.split(".")[-1]
        if filetype == 'csv':
            if not os.path.isfile(path):
                path = f"{path}/*.csv"
            if usecols:
                if parse_dates:
                    parse_dates = {key: parse_dates[key] for key in usecols if key in parse_dates}
                dtype = {key: dtype[key] for key in usecols if key in dtype}
            df = dd.read_csv(path, delimiter=self.delimiter, dtype=dtype, parse_dates=parse_dates, usecols=usecols)
        elif filetype == 'parquet':
            if not os.path.isfile(path):
                path = f"{path}/*.parquet"
            df = dd.read_parquet(path, dtype=dtype, parse_dates=parse_dates)
        else:
            raise TypeError("Only support CSV and Parquet file!")
    else:
        if not os.path.isfile(path):
            raise TypeError("Only support reading a single file!")
        filetype = path.split(".")[-1]
        if filetype == 'csv':
            if usecols:
                if parse_dates:
                    parse_dates = {key: parse_dates[key] for key in usecols if key in parse_dates}
                dtype = {key: dtype[key] for key in usecols if key in dtype}
            df = pd.read_csv(path, delimiter=self.delimiter, dtype=dtype, parse_dates=parse_dates, usecols=usecols)
        elif filetype == 'parquet':
            df = pd.read_parquet(path)
        else:
            raise TypeError("Only support CSV and Parquet file!")

    if index:
        df = df.set_index(index)
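A rough usage sketch of the reworked reader; the omop instance and the measurement file layout are illustrative assumptions, not part of this commit:

column_types, parse_dates = omop._get_column_types(omop.filepath["measurement"])

# Default: eager pandas read of a single CSV/Parquet file
df = omop._read_table(omop.filepath["measurement"], dtype=column_types, parse_dates=parse_dates, index="person_id")

# Opt-in: lazy dask read that globs a directory of CSV/Parquet chunks
ddf = omop._read_table(omop.filepath["measurement"], dtype=column_types, parse_dates=parse_dates, index="person_id", use_dask=True)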
@@ -376,7 +391,9 @@ def load(self, level="stay_level", tables=["visit_occurrence", "person", "death"
column_types, parse_dates = self._get_column_types(self.filepath[table])
df = self._read_table(self.filepath[table], dtype=column_types, parse_dates = parse_dates, index='person_id')
if remove_empty_column:
columns = [column for column in df.columns if not df[column].compute().isna().all()]
# TODO dask Support
#columns = [column for column in df.columns if not df[column].compute().isna().all()]
columns = [column for column in df.columns if not df[column].isna().all()]
df = df.loc[:, columns]
setattr(self, table, df)
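For reference, the column filter above simply drops columns that are entirely null; a minimal pandas sketch with made-up data:

import pandas as pd

df = pd.DataFrame({"person_id": [1, 2], "cause_source_value": [None, None]})
columns = [column for column in df.columns if not df[column].isna().all()]
df = df.loc[:, columns]  # keeps person_id, drops the all-NaN cause_source_value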

@@ -387,8 +404,10 @@ def load(self, level="stay_level", tables=["visit_occurrence", "person", "death"
# self.loaded_tabel = ['visit_occurrence', 'person', 'death', 'measurement', 'observation', 'drug_exposure']
joined_table = dd.merge(self.visit_occurrence, self.person, left_index=True, right_index=True, how="left")
joined_table = dd.merge(joined_table, self.death, left_index=True, right_index=True, how="left")

joined_table = joined_table.compute()

# TODO dask Support
#joined_table = joined_table.compute()

joined_table = joined_table.set_index("visit_occurrence_id")

# obs_only_list = list(self.joined_table.columns)
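A minimal, self-contained sketch of the join above, using tiny hand-made tables indexed by person_id (pandas is used here for illustration; the loader goes through dd.merge):

import pandas as pd

visit_occurrence = pd.DataFrame({"visit_occurrence_id": [10, 11]}, index=pd.Index([1, 2], name="person_id"))
person = pd.DataFrame({"year_of_birth": [1980, 1990]}, index=pd.Index([1, 2], name="person_id"))
death = pd.DataFrame({"death_date": ["2020-01-01"]}, index=pd.Index([2], name="person_id"))

joined_table = pd.merge(visit_occurrence, person, left_index=True, right_index=True, how="left")
joined_table = pd.merge(joined_table, death, left_index=True, right_index=True, how="left")
joined_table = joined_table.set_index("visit_occurrence_id")
# One row per visit; death_date is NaN for person 1, who has no death record.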
@@ -445,7 +464,9 @@ def feature_statistics(
):
column_types, parse_dates = self._get_column_types(self.filepath[source])
df_source = self._read_table(self.filepath[source], dtype=column_types, parse_dates = parse_dates, usecols=[f"{source}_concept_id"])
feature_counts = df_source[f"{source}_concept_id"].value_counts().compute()[0:number]
# TODO dask Support
#feature_counts = df_source[f"{source}_concept_id"].value_counts().compute()[0:number]
feature_counts = df_source[f"{source}_concept_id"].value_counts()[0:number]
feature_counts = feature_counts.to_frame().reset_index(drop=False)
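To illustrate the ranking above with toy data (concept ids invented for the example):

import pandas as pd

df_source = pd.DataFrame({"measurement_concept_id": [3004249, 3004249, 3027018, 3004249]})
number = 2
feature_counts = df_source["measurement_concept_id"].value_counts()[0:number]
feature_counts = feature_counts.to_frame().reset_index(drop=False)
# Two rows remain: concept 3004249 with count 3, concept 3027018 with count 1.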


@@ -479,9 +500,9 @@ def map_concept_id(self, concept_id: Union[str, List], verbose=True):
df_concept_relationship = self._read_csv(
self.filepath["concept_relationship"], dtype=column_types, parse_dates=parse_dates
)
df_concept_relationship.compute().dropna(
subset=["concept_id_1", "concept_id_2", "relationship_id"], inplace=True
) # , usecols=vocabularies_tables_columns["concept_relationship"],
# TODO dask Support
#df_concept_relationship.compute().dropna(subset=["concept_id_1", "concept_id_2", "relationship_id"], inplace=True) # , usecols=vocabularies_tables_columns["concept_relationship"],
df_concept_relationship.dropna(subset=["concept_id_1", "concept_id_2", "relationship_id"], inplace=True) # , usecols=vocabularies_tables_columns["concept_relationship"],
concept_relationship_dict = df_to_dict(
df=df_concept_relationship[df_concept_relationship["relationship_id"] == "Maps to"],
key="concept_id_1",
@@ -523,7 +544,9 @@ def get_concept_name(self, concept_id: Union[str, List], raise_error=False, verb

column_types, parse_dates = self._get_column_types(self.filepath["concept"])
df_concept = self._read_table(self.filepath["concept"], dtype=column_types, parse_dates=parse_dates)
df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
# TODO dask Support
#df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
df_concept.dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True) # usecols=vocabularies_tables_columns["concept"]
concept_dict = df_to_dict(df=df_concept, key="concept_id", value="concept_name")
concept_name = []
concept_name_not_found = []
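df_to_dict is a project helper that is not shown in this diff; the behaviour assumed by the calls above is roughly a key-column to value-column lookup, sketched here for clarity:

# Assumed behaviour of df_to_dict, for illustration only
def df_to_dict(df, key, value):
    return dict(zip(df[key], df[value]))

# concept_dict = df_to_dict(df=df_concept, key="concept_id", value="concept_name")
# concept_dict[concept_id] then yields the concept_name recorded for that id.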
@@ -609,23 +632,25 @@ def extract_features(

# TODO load using Dask or Dask-Awkward
# Load source table using dask
column_types, parse_dates = self._get_column_types(self.filepath[source])
source_column_types, parse_dates = self._get_column_types(self.filepath[source])
if parse_dates:
if len(parse_dates) == 1:
columns = list(column_types.keys()) + [parse_dates]
columns = list(source_column_types.keys()) + [parse_dates]
else:
columns = list(column_types.keys()) + parse_dates
columns = list(source_column_types.keys()) + parse_dates
else:
columns = list(column_types.keys())
columns = list(source_column_types.keys())
df_source = self._read_table(
self.filepath[source], dtype=column_types, #parse_dates=parse_dates
self.filepath[source], dtype=source_column_types, #parse_dates=parse_dates
) # , usecols=clinical_tables_columns[source]

if not features:
warnings.warn(
"Please specify desired features you want to extract. Otherwise, it will try to extract all the features!"
)
features = list(df_source[key].compute().unique())
# TODO dask Support
#features = list(df_source[key].compute().unique())
features = list(df_source[key].unique())
else:
if isinstance(features, str):
features = [features]
@@ -731,7 +756,9 @@ def extract_features(

# for feature_id, feature_name, domain_id, concept_class_id, concept_code in zip(feature_id_list, feature_name_list, domain_id_list, concept_class_id_list, concept_code_list):
try:
feature_df = df_source[df_source[key] == feature_id_2].compute()
# TODO dask Support
#feature_df = df_source[df_source[key] == feature_id_2].compute()
feature_df = df_source[df_source[key] == feature_id_2]
except:
print(f"Features ID could not be found in {source} table")
# TODO add checks if all columns exist in source table
@@ -740,19 +767,27 @@

if remove_empty_column:
columns = [column for column in columns if not feature_df[column].isna().all()]
feature_df = feature_df.loc[:, columns]
# TODO
#print()

if len(feature_df) > 0:

# Group by 'visit_occurrence_id' and aggregate the values
grouped = feature_df.groupby("visit_occurrence_id")[columns].agg(list)

# Convert the grouped data to a dictionary
grouped_dict = grouped.to_dict(orient='index')

# Create the final obs_dict
obs_dict = [
    {
        column: list(feature_df[feature_df["visit_occurrence_id"] == int(visit_occurrence_id)][column])
        for column in columns
    }
    for visit_occurrence_id in adata.obs.index
]
obs_dict = [
    grouped_dict.get(visit_occurrence_id, {col: [] for col in columns})
    for visit_occurrence_id in adata.obs.index.astype(int)
]

adata.obsm[feature_name] = ak.Array(obs_dict)

if add_aggregation_to_X:
unit = feature_df["unit_source_value"].value_counts().index[0]
if aggregation_methods is None:
aggregation_methods = ["min", "max", "mean"]
var_name_list = [
@@ -761,14 +796,10 @@
for aggregation_method in aggregation_methods:
func = getattr(ak, aggregation_method)
adata.obs[f"{feature_name}_{aggregation_method}"] = list(
func(adata.obsm[feature_name]["value_source_value"], axis=1)
func(adata.obsm[feature_name][f"{source}_source_value"], axis=1)
)
adata = ep.ad.move_to_x(adata, var_name_list)

adata.var.loc[var_name_list, "Unit"] = unit
adata.var.loc[var_name_list, "domain_id"] = domain_id
adata.var.loc[var_name_list, "concept_class_id"] = concept_class_id
adata.var.loc[var_name_list, "concept_code"] = concept_code

if len(fetures_not_shown_in_concept_table) > 0:
rprint(f"Couldn't find concept {fetures_not_shown_in_concept_table} in concept table!")