
Commit

move all previous functions to corresponding modules
xinyuejohn committed Feb 21, 2024
1 parent bfecc2b commit 4a9e960
Showing 6 changed files with 583 additions and 12 deletions.
2 changes: 1 addition & 1 deletion ehrdata/io/__init__.py
@@ -1 +1 @@
-from ehrdata.io._omop import from_dataframe, init_omop, to_dataframe
+from ehrdata.io._omop import extract_features, from_dataframe, init_omop, to_dataframe
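
With this re-export, the new helper can be imported directly from the package's io namespace, e.g.:

from ehrdata.io import extract_features
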
169 changes: 163 additions & 6 deletions ehrdata/io/_omop.py
@@ -6,18 +6,43 @@
import pandas as pd
from rich import print as rprint

-from ehrdata.utils._omop_utils import check_with_omop_cdm, get_column_types, get_table_catalog_dict, read_table
+from ehrdata.utils._omop_utils import (
+    check_with_omop_cdm,
+    get_column_types,
+    get_feature_info,
+    get_table_catalog_dict,
+    read_table,
+)


def init_omop(
-    folder_path,
-    delimiter=None,
-    make_filename_lowercase=True,
-    use_dask=False,
+    folder_path: str,
+    delimiter: Optional[str] = None,
+    make_filename_lowercase: bool = True,
+    use_dask: bool = False,
     level: Literal["stay_level", "patient_level"] = "stay_level",
     load_tables: Optional[Union[str, list[str], tuple[str], Literal["auto"]]] = ("visit_occurrence", "person"),
-    remove_empty_column=True,
+    remove_empty_column: bool = True,
):
"""_summary_
Args:
folder_path (str): _description_
delimiter (str, optional): _description_. Defaults to None.
make_filename_lowercase (bool, optional): _description_. Defaults to True.
use_dask (bool, optional): _description_. Defaults to False.
level (Literal["stay_level", "patient_level"], optional): _description_. Defaults to "stay_level".
load_tables (Optional[Union[str, list[str], tuple[str], Literal["auto"]]], optional): _description_. Defaults to ("visit_occurrence", "person").
remove_empty_column (bool, optional): _description_. Defaults to True.
Raises
------
ValueError: _description_
Returns
-------
_type_: _description_
"""
filepath_dict = check_with_omop_cdm(
folder_path=folder_path, delimiter=delimiter, make_filename_lowercase=make_filename_lowercase
)
@@ -249,6 +274,138 @@ def init_omop(
return adata
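
A minimal usage sketch for init_omop as defined above; the folder path and delimiter are hypothetical placeholders for a local OMOP CDM export:

from ehrdata.io import init_omop

adata = init_omop(
    folder_path="data/omop_export",  # placeholder path, not shipped with this repo
    delimiter=",",                   # placeholder delimiter
    level="stay_level",
    load_tables=("visit_occurrence", "person"),
    use_dask=False,
)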


def extract_features(
adata,
source: Literal[
"observation",
"measurement",
"procedure_occurrence",
"specimen",
"device_exposure",
"drug_exposure",
"condition_occurrence",
],
    features: Optional[Union[str, int, list[Union[str, int]]]] = None,
    source_table_columns: Optional[Union[str, list[str]]] = None,
    dropna: Optional[bool] = True,
    verbose: Optional[bool] = True,
    use_dask: Optional[bool] = None,
):
    """Extract the given features from an OMOP source table into ``adata.obsm``."""
if source in ["measurement", "observation", "specimen"]:
key = f"{source}_concept_id"
elif source in ["device_exposure", "procedure_occurrence", "drug_exposure", "condition_occurrence"]:
key = f"{source.split('_')[0]}_concept_id"
else:
raise KeyError(f"Extracting data from {source} is not supported yet")

if source_table_columns is None:
if source == "measurement":
source_table_columns = ["visit_occurrence_id", "measurement_datetime", "value_as_number", key]
elif source == "observation":
source_table_columns = [
"visit_occurrence_id",
"value_as_number",
"value_as_string",
"observation_datetime",
key,
]
elif source == "condition_occurrence":
source_table_columns = None
else:
raise KeyError(f"Extracting data from {source} is not supported yet")
    if use_dask is None:
        use_dask = adata.uns["use_dask"]
# TODO load using Dask or Dask-Awkward
# Load source table using dask
source_column_types = get_column_types(adata.uns, table_name=source)
    df_source = read_table(
        adata.uns, table_name=source, dtype=source_column_types, usecols=source_table_columns, use_dask=use_dask
    )
    # If no column subset was requested (e.g. for condition_occurrence), fall back to
    # all columns so the per-feature grouping below has a concrete column list.
    if source_table_columns is None:
        source_table_columns = list(df_source.columns)
    info_df = get_feature_info(adata.uns, features=features, verbose=False)
info_dict = info_df[["feature_id", "feature_name"]].set_index("feature_id").to_dict()["feature_name"]

    # Select features
df_source = df_source[df_source[key].isin(list(info_df.feature_id))]

# TODO select time period
# df_source = df_source[(df_source.time >= 0) & (df_source.time <= 48*60*60)]
# da_measurement['measurement_name'] = da_measurement.measurement_concept_id.replace(info_dict)

# TODO dask caching
"""
from dask.cache import Cache
cache = Cache(2e9)
cache.register()
"""
if use_dask:
if dropna:
df_source = df_source.compute().dropna()
else:
df_source = df_source.compute()
else:
if dropna:
df_source = df_source.dropna()

# Preprocess steps outside the loop
unique_visit_occurrence_ids = set(adata.obs.index.astype(int))
empty_entry = {
source_table_column: []
for source_table_column in source_table_columns
if source_table_column not in [key, "visit_occurrence_id"]
}

# Filter data once, if possible
filtered_data = {feature_id: df_source[df_source[key] == feature_id] for feature_id in set(info_dict.keys())}

for feature_id in set(info_dict.keys()):
df_feature = filtered_data[feature_id][list(set(source_table_columns) - {key})]
grouped = df_feature.groupby("visit_occurrence_id")
if verbose:
print(f"Adding feature [{info_dict[feature_id]}] into adata.obsm")

# Use set difference and intersection more efficiently
feature_ids = unique_visit_occurrence_ids.intersection(grouped.groups.keys())

# Creating the array more efficiently
adata.obsm[info_dict[feature_id]] = ak.Array(
[
(
grouped.get_group(visit_occurrence_id)[
list(set(source_table_columns) - {key, "visit_occurrence_id"})
].to_dict(orient="list")
if visit_occurrence_id in feature_ids
else empty_entry
)
for visit_occurrence_id in unique_visit_occurrence_ids
]
)

return adata
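
A sketch of calling extract_features as implemented above; the concept ID is a placeholder for one that actually occurs in your vocabulary:

from ehrdata.io import extract_features

adata = extract_features(
    adata,
    source="measurement",
    features=[3027018],  # placeholder measurement concept ID
    dropna=True,
)
# Each requested feature is stored as a ragged awkward Array in
# adata.obsm[<feature_name>], one entry per visit in adata.obs.index.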


def extract_note(
    adata,
    use_dask: Optional[bool] = None,
    columns: Optional[list[str]] = None,
):
    """Extract the OMOP note table into ``adata.obsm["note"]``."""
    if use_dask is None:
        use_dask = adata.uns["use_dask"]
source_column_types = get_column_types(adata.uns, table_name="note")
df_source = read_table(adata.uns, table_name="note", dtype=source_column_types, use_dask=use_dask)
if columns is None:
columns = df_source.columns
obs_dict = [
{
column: list(df_source[df_source["visit_occurrence_id"] == int(visit_occurrence_id)][column])
for column in columns
}
for visit_occurrence_id in adata.obs.index
]
adata.obsm["note"] = ak.Array(obs_dict)
return adata
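
A sketch of calling extract_note; it is not re-exported in ehrdata/io/__init__.py above, hence the private-module import, and the column names are standard OMOP note columns assumed to be present in the export:

from ehrdata.io._omop import extract_note

adata = extract_note(adata, columns=["note_datetime", "note_text"])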


def from_dataframe(adata, feature: str, df):
grouped = df.groupby("visit_occurrence_id")
unique_visit_occurrence_ids = set(adata.obs.index)
170 changes: 168 additions & 2 deletions ehrdata/pl/_omop.py
@@ -1,13 +1,19 @@
-from typing import Literal
+from collections.abc import Sequence
+from functools import partial
+from typing import Literal, Optional, Union

import ehrapy as ep
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
from anndata import AnnData
from matplotlib.axes import Axes

from ehrdata.io._omop import to_dataframe
from ehrdata.tl import get_concept_name
from ehrdata.utils._omop_utils import get_column_types, map_concept_id, read_table


# TODO allow users to pass features
def feature_counts(
adata,
source: Literal[
@@ -55,3 +61,163 @@ def feature_counts(
ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha="right")
plt.tight_layout()
return feature_counts
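
A sketch of calling feature_counts from this module, using only the parameters visible above; the source value is one of the accepted literals:

from ehrdata.pl._omop import feature_counts

feature_counts(adata, source="measurement")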


def plot_timeseries(
adata: AnnData,
visit_occurrence_id: int,
key: Union[str, list[str]],
    slot: Optional[str] = "obsm",
    value_key: str = "value_as_number",
    time_key: str = "measurement_datetime",
    x_label: Optional[str] = None,
):
    """Plot the time series of one or more ``adata.obsm`` features for a single visit."""
if isinstance(key, str):
key_list = [key]
else:
key_list = key

# Initialize min_x and max_x
min_x = None
max_x = None

if slot == "obsm":
_, ax = plt.subplots(figsize=(20, 6))
# Scatter plot
for key in key_list:
df = to_dataframe(adata, key)
x = df[df.visit_occurrence_id == visit_occurrence_id][time_key]
y = df[df.visit_occurrence_id == visit_occurrence_id][value_key]

# Check if x is empty
if not x.empty:
ax.scatter(x=x, y=y, label=key)
ax.legend(loc=9, bbox_to_anchor=(0.5, -0.1), ncol=len(key_list), prop={"size": 14})

ax.plot(x, y)

if min_x is None or min_x > x.min():
min_x = x.min()
if max_x is None or max_x < x.max():
max_x = x.max()

else:
# Skip this iteration if x is empty
continue

if min_x is not None and max_x is not None:
# Adapt this to input data
# TODO step
# plt.xticks(np.arange(min_x, max_x, step=1))
# Adapt this to input data
plt.xlabel(x_label if x_label else "Hours since ICU admission")

plt.show()
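
A sketch of calling plot_timeseries; the visit ID and obsm keys are placeholders for features previously created by extract_features:

from ehrdata.pl._omop import plot_timeseries

plot_timeseries(
    adata,
    visit_occurrence_id=123456,            # placeholder visit ID
    key=["Heart rate", "Blood pressure"],  # placeholder obsm keys
    slot="obsm",
    value_key="value_as_number",
    time_key="measurement_datetime",
    x_label="Measurement time",
)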


def violin(
adata: AnnData,
    obsm_key: Optional[str] = None,
    keys: Optional[Union[str, Sequence[str]]] = None,
groupby: Optional[str] = None,
log: Optional[bool] = False,
use_raw: Optional[bool] = None,
stripplot: bool = True,
jitter: Union[float, bool] = True,
size: int = 1,
layer: Optional[str] = None,
scale: Literal["area", "count", "width"] = "width",
order: Optional[Sequence[str]] = None,
multi_panel: Optional[bool] = None,
xlabel: str = "",
    ylabel: Optional[Union[str, Sequence[str]]] = None,
rotation: Optional[float] = None,
show: Optional[bool] = None,
    save: Optional[Union[bool, str]] = None,
ax: Optional[Axes] = None,
**kwds,
): # pragma: no cover
"""Violin plot.
Wraps :func:`seaborn.violinplot` for :class:`~anndata.AnnData`.
Args:
        adata: :class:`~anndata.AnnData` object containing all observations.
        obsm_key: Key in `adata.obsm` holding the feature to plot; if given, a new AnnData is built from that feature before plotting.
keys: Keys for accessing variables of `.var_names` or fields of `.obs`.
groupby: The key of the observation grouping to consider.
log: Plot on logarithmic axis.
use_raw: Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.
stripplot: Add a stripplot on top of the violin plot. See :func:`~seaborn.stripplot`.
        jitter: Add jitter to the stripplot (only when stripplot is True). See :func:`~seaborn.stripplot`.
size: Size of the jitter points.
        layer: Name of the AnnData layer to plot. By default `adata.raw.X` is
            plotted. If `use_raw=False` is set, then `adata.X` is plotted. If
            `layer` is set to a valid layer name, then that layer is plotted.
            `layer` takes precedence over `use_raw`.
scale: The method used to scale the width of each violin.
If 'width' (the default), each violin will have the same width.
If 'area', each violin will have the same area.
If 'count', a violin’s width corresponds to the number of observations.
order: Order in which to show the categories.
multi_panel: Display keys in multiple panels also when `groupby is not None`.
xlabel: Label of the x axis. Defaults to `groupby` if `rotation` is `None`, otherwise, no label is shown.
ylabel: Label of the y axis. If `None` and `groupby` is `None`, defaults to `'value'`.
        If `None` and `groupby` is not `None`, defaults to `keys`.
rotation: Rotation of xtick labels.
{show_save_ax}
**kwds:
Are passed to :func:`~seaborn.violinplot`.
Returns
-------
A :class:`~matplotlib.axes.Axes` object if `ax` is `None` else `None`.
Example:
.. code-block:: python
import ehrapy as ep
adata = ep.dt.mimic_2(encoded=True)
ep.pp.knn_impute(adata)
ep.pp.log_norm(adata, offset=1)
ep.pp.neighbors(adata)
ep.tl.leiden(adata, resolution=0.5, key_added="leiden_0_5")
ep.pl.violin(adata, keys=["age"], groupby="leiden_0_5")
Preview:
.. image:: /_static/docstring_previews/violin.png
"""
if obsm_key:
df = to_dataframe(adata, features=obsm_key)
df = df[["visit_occurrence_id", "value_as_number"]]
df = df.rename(columns={"value_as_number": obsm_key})

if groupby:
df = df.set_index("visit_occurrence_id").join(adata.obs[groupby].to_frame()).reset_index(drop=False)
adata = ep.ad.df_to_anndata(df, columns_obs_only=["visit_occurrence_id", groupby])
else:
adata = ep.ad.df_to_anndata(df, columns_obs_only=["visit_occurrence_id"])
keys = obsm_key

violin_partial = partial(
sc.pl.violin,
keys=keys,
log=log,
use_raw=use_raw,
stripplot=stripplot,
jitter=jitter,
size=size,
layer=layer,
scale=scale,
order=order,
multi_panel=multi_panel,
xlabel=xlabel,
ylabel=ylabel,
rotation=rotation,
show=show,
save=save,
ax=ax,
**kwds,
)

return violin_partial(adata=adata, groupby=groupby)
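
A sketch of calling violin; "Heart rate" stands in for a key created by extract_features in adata.obsm, and "gender_source_value" for a column of adata.obs:

from ehrdata.pl._omop import violin

violin(adata, obsm_key="Heart rate", groupby="gender_source_value")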
