diff --git a/ehrdata.py b/ehrdata.py
index b7f6aee..e8ac9df 100644
--- a/ehrdata.py
+++ b/ehrdata.py
@@ -221,10 +221,8 @@ def _get_column_types(self, path=None, columns=None):
             raise TypeError("Only support CSV and Parquet file!")
         columns_lowercase = [column.lower() for column in columns]
         for i, column in enumerate(columns_lowercase):
-            if hasattr(self, "additional_column"):
-                if column in self.additional_column.keys():
-                    column_types[columns[i]] = self.additional_column[column]
-
+            if hasattr(self, "additional_column") and column in self.additional_column.keys():
+                column_types[columns[i]] = self.additional_column[column]
             elif column.endswith(
                 (
                     "source_value",
@@ -247,7 +245,7 @@ def _get_column_types(self, path=None, columns=None):
                     "lot_number",
                 )
             ):
-                column_types[columns[i]] = str
+                column_types[columns[i]] = object
             # TODO quantity in different tables have different types
             elif column.endswith(("as_number", "low", "high", "quantity")):
                 column_types[columns[i]] = float
@@ -257,34 +255,51 @@ def _get_column_types(self, path=None, columns=None):
                 parse_dates.append(columns[i])
             elif column.endswith(("id", "birth", "id_1", "id_2", "refills", "days_supply")):
                 column_types[columns[i]] = "Int64"
+
             else:
                 raise KeyError(f"{columns[i]} is not defined in OMOP CDM")
         if len(parse_dates) == 0:
             parse_dates = None
         return column_types, parse_dates
 
-    def _read_table(self, path, dtype=None, parse_dates=None, index=None, usecols=None, **kwargs):
-
-        if not os.path.isfile(path):
-            folder_walk = os.walk(path)
-            filetype = next(folder_walk)[2][0].split(".")[-1]
-        else:
-            filetype = path.split(".")[-1]
-        if filetype == 'csv':
-            if not os.path.isfile(path):
-                path = f"{path}/*.csv"
-            if usecols:
-                if parse_dates:
-                    parse_dates = {key: parse_dates[key] for key in usecols if key in parse_dates}
-                if usecols:
-                    dtype = {key: dtype[key] for key in usecols if key in dtype}
-            df = dd.read_csv(path, delimiter=self.delimiter, dtype=dtype, parse_dates=parse_dates, usecols=usecols)
-        elif filetype == 'parquet':
-            if not os.path.isfile(path):
-                path = f"{path}/*.parquet"
-            df = dd.read_parquet(path, dtype=dtype, parse_dates=parse_dates)
-        else:
-            raise TypeError("Only support CSV and Parquet file!")
+    def _read_table(self, path, dtype=None, parse_dates=None, index=None, usecols=None, use_dask=False, **kwargs):
+
+        if use_dask:
+            if not os.path.isfile(path):
+                folder_walk = os.walk(path)
+                filetype = next(folder_walk)[2][0].split(".")[-1]
+            else:
+                filetype = path.split(".")[-1]
+            if filetype == 'csv':
+                if not os.path.isfile(path):
+                    path = f"{path}/*.csv"
+                if usecols:
+                    if parse_dates:
+                        parse_dates = {key: parse_dates[key] for key in usecols if key in parse_dates}
+                    dtype = {key: dtype[key] for key in usecols if key in dtype}
+                df = dd.read_csv(path, delimiter=self.delimiter, dtype=dtype, parse_dates=parse_dates, usecols=usecols)
+            elif filetype == 'parquet':
+                if not os.path.isfile(path):
+                    path = f"{path}/*.parquet"
+                df = dd.read_parquet(path, dtype=dtype, parse_dates=parse_dates)
+            else:
+                raise TypeError("Only support CSV and Parquet file!")
+        else:
+            if not os.path.isfile(path):
+                raise TypeError("Only support reading a single file!")
+            filetype = path.split(".")[-1]
+            if filetype == 'csv':
+                if usecols:
+                    if parse_dates:
+                        parse_dates = {key: parse_dates[key] for key in usecols if key in parse_dates}
+                    dtype = {key: dtype[key] for key in usecols if key in dtype}
+                df = pd.read_csv(path, delimiter=self.delimiter, dtype=dtype, parse_dates=parse_dates, usecols=usecols)
+            elif filetype == 'parquet':
+                df = pd.read_parquet(path)
+            else:
+                raise TypeError("Only support CSV and Parquet file!")
 
         if index:
             df = df.set_index(index)
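
The new use_dask switch makes the eager pandas path the default and confines dask to an opt-in. A minimal usage sketch of the two paths, assuming `ed` is an instance of the patched class with `filepath` and `delimiter` set (the variable names are illustrative, not part of the diff):

    # Eager default: a single CSV or Parquet file, read straight into pandas.
    column_types, parse_dates = ed._get_column_types(ed.filepath["person"])
    person_pd = ed._read_table(
        ed.filepath["person"], dtype=column_types, parse_dates=parse_dates, index="person_id"
    )

    # Opt-in lazy path: also accepts a folder of *.csv / *.parquet partitions
    # and returns a dask DataFrame that is only materialized on .compute().
    person_dd = ed._read_table(
        ed.filepath["person"], dtype=column_types, parse_dates=parse_dates,
        index="person_id", use_dask=True,
    )
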
@@ -376,7 +391,9 @@ def load(self, level="stay_level", tables=["visit_occurrence", "person", "death"
             column_types, parse_dates = self._get_column_types(self.filepath[table])
             df = self._read_table(self.filepath[table], dtype=column_types, parse_dates=parse_dates, index='person_id')
             if remove_empty_column:
-                columns = [column for column in df.columns if not df[column].compute().isna().all()]
+                # TODO dask Support
+                #columns = [column for column in df.columns if not df[column].compute().isna().all()]
+                columns = [column for column in df.columns if not df[column].isna().all()]
             df = df.loc[:, columns]
             setattr(self, table, df)
 
@@ -387,8 +404,10 @@ def load(self, level="stay_level", tables=["visit_occurrence", "person", "death"
         # self.loaded_tabel = ['visit_occurrence', 'person', 'death', 'measurement', 'observation', 'drug_exposure']
         joined_table = dd.merge(self.visit_occurrence, self.person, left_index=True, right_index=True, how="left")
         joined_table = dd.merge(joined_table, self.death, left_index=True, right_index=True, how="left")
-
-        joined_table = joined_table.compute()
+
+        # TODO dask Support
+        #joined_table = joined_table.compute()
+        joined_table = joined_table.set_index("visit_occurrence_id")
 
         # obs_only_list = list(self.joined_table.columns)
@@ -445,7 +464,9 @@ def feature_statistics(
     ):
         column_types, parse_dates = self._get_column_types(self.filepath[source])
         df_source = self._read_table(self.filepath[source], dtype=column_types, parse_dates=parse_dates, usecols=[f"{source}_concept_id"])
-        feature_counts = df_source[f"{source}_concept_id"].value_counts().compute()[0:number]
+        # TODO dask Support
+        #feature_counts = df_source[f"{source}_concept_id"].value_counts().compute()[0:number]
+        feature_counts = df_source[f"{source}_concept_id"].value_counts()[0:number]
         feature_counts = feature_counts.to_frame().reset_index(drop=False)
@@ -479,9 +500,9 @@ def map_concept_id(self, concept_id: Union[str, List], verbose=True):
         df_concept_relationship = self._read_csv(
             self.filepath["concept_relationship"], dtype=column_types, parse_dates=parse_dates
         )
-        df_concept_relationship.compute().dropna(
-            subset=["concept_id_1", "concept_id_2", "relationship_id"], inplace=True
-        )  # , usecols=vocabularies_tables_columns["concept_relationship"],
+        # TODO dask Support
+        #df_concept_relationship.compute().dropna(subset=["concept_id_1", "concept_id_2", "relationship_id"], inplace=True)  # , usecols=vocabularies_tables_columns["concept_relationship"],
+        df_concept_relationship.dropna(subset=["concept_id_1", "concept_id_2", "relationship_id"], inplace=True)  # , usecols=vocabularies_tables_columns["concept_relationship"],
         concept_relationship_dict = df_to_dict(
             df=df_concept_relationship[df_concept_relationship["relationship_id"] == "Maps to"],
             key="concept_id_1",
@@ -523,7 +544,9 @@ def get_concept_name(self, concept_id: Union[str, List], raise_error=False, verb
 
         column_types, parse_dates = self._get_column_types(self.filepath["concept"])
         df_concept = self._read_table(self.filepath["concept"], dtype=column_types, parse_dates=parse_dates)
-        df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True)  # usecols=vocabularies_tables_columns["concept"]
+        # TODO dask Support
+        #df_concept.compute().dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True)  # usecols=vocabularies_tables_columns["concept"]
+        df_concept.dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True)  # usecols=vocabularies_tables_columns["concept"]
         concept_dict = df_to_dict(df=df_concept, key="concept_id", value="concept_name")
         concept_name = []
         concept_name_not_found = []
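
One detail worth noting about the lines commented out above: calling .dropna(inplace=True) on the result of .compute() mutated a temporary pandas copy and then discarded it, so the rows were never removed from the frame used afterwards. A small self-contained illustration with synthetic data (not part of the patch):

    import pandas as pd

    df_concept = pd.DataFrame({"concept_id": [1, None], "concept_name": ["a", None]})

    tmp = df_concept.copy()  # stands in for the pandas frame returned by .compute()
    tmp.dropna(subset=["concept_id", "concept_name"], inplace=True)
    assert len(df_concept) == 2  # original untouched: the old call dropped nothing

    # The patched pandas path mutates the frame that is actually used afterwards:
    df_concept.dropna(subset=["concept_id", "concept_name"], inplace=True, ignore_index=True)
    assert len(df_concept) == 1
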
@@ -609,23 +632,25 @@ def extract_features(
         # TODO load using Dask or Dask-Awkward
         # Load source table using dask
-        column_types, parse_dates = self._get_column_types(self.filepath[source])
+        source_column_types, parse_dates = self._get_column_types(self.filepath[source])
         if parse_dates:
             if len(parse_dates) == 1:
-                columns = list(column_types.keys()) + [parse_dates]
+                columns = list(source_column_types.keys()) + [parse_dates]
             else:
-                columns = list(column_types.keys()) + parse_dates
+                columns = list(source_column_types.keys()) + parse_dates
         else:
-            columns = list(column_types.keys())
+            columns = list(source_column_types.keys())
         df_source = self._read_table(
-            self.filepath[source], dtype=column_types, #parse_dates=parse_dates
+            self.filepath[source], dtype=source_column_types, #parse_dates=parse_dates
         )  # , usecols=clinical_tables_columns[source]
         if not features:
             warnings.warn(
                 "Please specify desired features you want to extract. Otherwise, it will try to extract all the features!"
             )
-            features = list(df_source[key].compute().unique())
+            # TODO dask Support
+            #features = list(df_source[key].compute().unique())
+            features = list(df_source[key].unique())
         else:
             if isinstance(features, str):
                 features = [features]
@@ -731,7 +756,9 @@ def extract_features(
         # for feature_id, feature_name, domain_id, concept_class_id, concept_code in zip(feature_id_list, feature_name_list, domain_id_list, concept_class_id_list, concept_code_list):
             try:
-                feature_df = df_source[df_source[key] == feature_id_2].compute()
+                # TODO dask Support
+                #feature_df = df_source[df_source[key] == feature_id_2].compute()
+                feature_df = df_source[df_source[key] == feature_id_2]
             except:
                 print(f"Features ID could not be found in {source} table")
             # TODO add checks if all columns exist in source table
@@ -740,19 +767,27 @@ def extract_features(
             if remove_empty_column:
                 columns = [column for column in columns if not feature_df[column].isna().all()]
+
             feature_df = feature_df.loc[:, columns]
+            # TODO
+            #print()
             if len(feature_df) > 0:
+
+                # Group by 'visit_occurrence_id' and aggregate the values
+                grouped = feature_df.groupby("visit_occurrence_id")[columns].agg(list)
+
+                # Convert the grouped data to a dictionary
+                grouped_dict = grouped.to_dict(orient='index')
+
+                # Create the final obs_dict
                 obs_dict = [
-                    {
-                        column: list(feature_df[feature_df["visit_occurrence_id"] == int(visit_occurrence_id)][column])
-                        for column in columns
-                    }
-                    for visit_occurrence_id in adata.obs.index
+                    grouped_dict.get(visit_occurrence_id, {col: [] for col in columns})
+                    for visit_occurrence_id in adata.obs.index.astype(int)
                 ]
+
                 adata.obsm[feature_name] = ak.Array(obs_dict)
 
                 if add_aggregation_to_X:
-                    unit = feature_df["unit_source_value"].value_counts().index[0]
                     if aggregation_methods is None:
                         aggregation_methods = ["min", "max", "mean"]
                     var_name_list = [
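
The groupby rewrite above replaces a per-visit boolean filter (one full scan of feature_df per visit) with a single grouped pass plus a dictionary lookup, and visits missing from the source table now fall back to empty lists instead of empty scans. A self-contained sketch of the pattern on synthetic data (names mirror the diff but the values are made up):

    import awkward as ak
    import pandas as pd

    feature_df = pd.DataFrame({
        "visit_occurrence_id": [1, 1, 2],
        "value_source_value": ["7.4", "7.1", "6.9"],
    })
    columns = ["value_source_value"]
    visit_ids = [1, 2, 3]  # stands in for adata.obs.index.astype(int)

    # One pass over the table: each visit maps to a dict of per-column lists.
    grouped_dict = (
        feature_df.groupby("visit_occurrence_id")[columns].agg(list).to_dict(orient="index")
    )
    # Visits with no rows (here, visit 3) get empty lists rather than a KeyError.
    obs_dict = [grouped_dict.get(v, {col: [] for col in columns}) for v in visit_ids]
    print(ak.Array(obs_dict))
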
@@ -761,14 +796,10 @@ def extract_features(
                     for aggregation_method in aggregation_methods:
                         func = getattr(ak, aggregation_method)
                         adata.obs[f"{feature_name}_{aggregation_method}"] = list(
-                            func(adata.obsm[feature_name]["value_source_value"], axis=1)
+                            func(adata.obsm[feature_name][f"{source}_source_value"], axis=1)
                         )
                     adata = ep.ad.move_to_x(adata, var_name_list)
-                    adata.var.loc[var_name_list, "Unit"] = unit
-                    adata.var.loc[var_name_list, "domain_id"] = domain_id
-                    adata.var.loc[var_name_list, "concept_class_id"] = concept_class_id
-                    adata.var.loc[var_name_list, "concept_code"] = concept_code
 
         if len(fetures_not_shown_in_concept_table) > 0:
             rprint(f"Couldn't find concept {fetures_not_shown_in_concept_table} in concept table!")
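
The aggregation loop in the last hunk reduces each visit's ragged value list with awkward's reducers, looked up dynamically via getattr. A tiny standalone sketch of that reduction on synthetic numeric data (the real `{source}_source_value` column holds strings per the OMOP dtype mapping above, so it would need casting first):

    import awkward as ak

    # One ragged list of measurements per visit; empty visits are allowed.
    values = ak.Array([[7.4, 7.1], [6.9], []])
    for method in ["min", "max", "mean"]:
        func = getattr(ak, method)  # mirrors func = getattr(ak, aggregation_method)
        # axis=1 reduces within each visit; empty visits become missing values
        # instead of raising, which downstream code must tolerate.
        print(method, func(values, axis=1).tolist())
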