refactor hospital admission to use delphi_utils create_export_csv #2032

Closed
wants to merge 25 commits into from
Changes from 19 commits
9 changes: 8 additions & 1 deletion _delphi_utils_python/delphi_utils/export.py
@@ -116,7 +116,7 @@ def create_export_csv(
"sample_size",
"missing_val",
"missing_se",
"missing_sample_size"
"missing_sample_size",
]
export_df = df[df["timestamp"] == date].filter(items=expected_columns)
if "missing_val" in export_df.columns:
@@ -129,4 +129,11 @@
if sort_geos:
export_df = export_df.sort_values(by="geo_id")
export_df.to_csv(export_file, index=False, na_rep="NA")

logger.debug(
"Wrote rows",
num_rows=df.size,
geo_type=geo_res,
num_geo_ids=export_df["geo_id"].unique().size,
)
return dates
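
For reference, here is a minimal sketch of calling the updated function with the new logger argument. The keyword arguments mirror the tests below; the DataFrame contents, the export directory, and the use of get_structured_logger here are illustrative assumptions, not code from this PR.

```python
# Sketch only: exercising create_export_csv with the new `logger` keyword.
# Data and paths are made up; column names follow the delphi_utils convention
# (geo_id, timestamp, val, se, sample_size).
from datetime import datetime

import pandas as pd
from delphi_utils import create_export_csv, get_structured_logger

logger = get_structured_logger(__name__)

df = pd.DataFrame({
    "geo_id": ["51093", "51175"],
    "timestamp": [datetime(2020, 2, 15)] * 2,
    "val": [3.1, 2.1],
    "se": [0.15, 0.22],
    "sample_size": [100, 100],
})

# Writes 20200215_county_test.csv under export_dir (assumed to exist) and,
# for each exported date, emits the new "Wrote rows" debug record with
# geo_type and num_geo_ids fields.
create_export_csv(
    df=df,
    export_dir="./receiving",
    geo_res="county",
    sensor="test",
    logger=logger,
)
```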
23 changes: 18 additions & 5 deletions _delphi_utils_python/tests/test_export.py
@@ -1,4 +1,5 @@
"""Tests for exporting CSV files."""
import logging
from datetime import datetime
from os import listdir
from os.path import join
@@ -11,6 +12,7 @@

from delphi_utils import create_export_csv, Nans

TEST_LOGGER = logging.getLogger()

def _set_df_dtypes(df: pd.DataFrame, dtypes: Dict[str, Any]) -> pd.DataFrame:
assert all(isinstance(e, type) or isinstance(e, str) for e in dtypes.values()), (
@@ -102,6 +104,7 @@ def test_export_with_metric(self, tmp_path):
metric="deaths",
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -122,6 +125,7 @@ def test_export_rounding(self, tmp_path):
metric="deaths",
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)
assert_frame_equal(
pd.read_csv(join(tmp_path, "20200215_county_deaths_test.csv")),
@@ -144,6 +148,7 @@ def test_export_without_metric(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -163,6 +168,7 @@ def test_export_with_limiting_start_date(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -182,6 +188,7 @@ def test_export_with_limiting_end_date(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -199,6 +206,7 @@ def test_export_with_no_dates(self, tmp_path):
export_dir=tmp_path,
geo_res="state",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -228,7 +236,8 @@ def test_export_with_null_removal(self, tmp_path):
export_dir=tmp_path,
geo_res="state",
sensor="test",
remove_null_samples=True
remove_null_samples=True,
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -259,7 +268,8 @@ def test_export_without_null_removal(self, tmp_path):
export_dir=tmp_path,
geo_res="state",
sensor="test",
remove_null_samples=False
remove_null_samples=False,
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -275,7 +285,7 @@ def test_export_without_null_removal(self, tmp_path):
def test_export_df_without_missingness(self, tmp_path):

create_export_csv(
df=self.DF.copy(), export_dir=tmp_path, geo_res="county", sensor="test"
df=self.DF.copy(), export_dir=tmp_path, geo_res="county", sensor="test", logger=TEST_LOGGER
)
df = pd.read_csv(join(tmp_path, "20200215_county_test.csv")).astype(
{"geo_id": str, "sample_size": int}
@@ -297,6 +307,7 @@ def test_export_df_with_missingness(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)
assert set(listdir(tmp_path)) == set(
[
@@ -358,7 +369,8 @@ def test_export_sort(self, tmp_path):
unsorted_df,
export_dir=tmp_path,
geo_res="county",
sensor="test"
sensor="test",
logger=TEST_LOGGER
)
expected_df = pd.DataFrame({
"geo_id": ["51175", "51093"],
@@ -374,7 +386,8 @@
export_dir=tmp_path,
geo_res="county",
sensor="test",
sort_geos=True
sort_geos=True,
logger=TEST_LOGGER
)
expected_df = pd.DataFrame({
"geo_id": ["51093", "51175"],
3 changes: 2 additions & 1 deletion changehc/delphi_changehc/update_sensor.py
@@ -60,7 +60,8 @@ def write_to_csv(df, geo_level, write_se, day_shift, out_name, logger, output_pa
start_date=start_date,
end_date=end_date,
sensor=out_name,
write_empty_days=True
write_empty_days=True,
logger=logger,
)
logger.debug("wrote {0} rows for {1} {2}".format(
df.size, df["geo_id"].unique().size, geo_level
23 changes: 15 additions & 8 deletions claims_hosp/delphi_claims_hosp/run.py
@@ -5,22 +5,20 @@
when the module is run with `python -m delphi_claims_hosp`.
"""

# standard packages
import time
import os
import time
from datetime import datetime, timedelta
from pathlib import Path

# third party
from delphi_utils import get_structured_logger
from delphi_utils.export import create_export_csv

# first party
from .backfill import merge_backfill_file, store_backfill_file
from .config import Config
from .download_claims_ftp_files import download
from .modify_claims_drops import modify_and_write
from .get_latest_claims_name import get_latest_filename
from .modify_claims_drops import modify_and_write
from .update_indicator import ClaimsHospIndicatorUpdater
from .backfill import (store_backfill_file, merge_backfill_file)


def run_module(params):
@@ -137,11 +135,20 @@ def run_module(params):
params["indicator"]["write_se"],
signal_name
)
updater.update_indicator(
output = updater.update_indicator(
claims_file,
params["common"]["export_dir"],
logger,
)
filtered_output_df = updater.preprocess_output(output)
create_export_csv(
filtered_output_df,
export_dir=params["common"]["export_dir"],
start_date=startdate,
geo_res=geo,
sensor=signal_name,
logger=logger,
)

max_dates.append(updater.output_dates[-1])
n_csv_export.append(len(updater.output_dates))
logger.info("finished updating", geo = geo)
136 changes: 52 additions & 84 deletions claims_hosp/delphi_claims_hosp/update_indicator.py
@@ -133,18 +133,15 @@ def geo_reindex(self, data):
data_frame.fillna(0, inplace=True)
return data_frame

def update_indicator(self, input_filepath, outpath, logger):
def update_indicator(self, input_filepath, logger):
"""
Generate and output indicator values.

Args:
input_filepath: path to the aggregated claims data
outpath: output path for the csv results

"""
self.shift_dates()
final_output_inds = \
(self.burn_in_dates >= self.startdate) & (self.burn_in_dates <= self.enddate)
final_output_inds = (self.burn_in_dates >= self.startdate) & (self.burn_in_dates <= self.enddate)

# load data
base_geo = Config.HRR_COL if self.geo == Config.HRR_COL else Config.FIPS_COL
@@ -164,22 +161,19 @@ def update_indicator(self, input_filepath, outpath, logger):
if self.weekday
else None
)
# run fitting code (maybe in parallel)
rates = {}
std_errs = {}
valid_inds = {}
df_lst = []
output_df = pd.DataFrame()
if not self.parallel:
for geo_id, sub_data in data_frame.groupby(level=0):
sub_data.reset_index(inplace=True)
if self.weekday:
sub_data = Weekday.calc_adjustment(
wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data.set_index(Config.DATE_COL, inplace=True)
res = ClaimsHospIndicator.fit(sub_data, self.burnindate, geo_id)
res = pd.DataFrame(res)
rates[geo_id] = np.array(res.loc[final_output_inds, "rate"])
std_errs[geo_id] = np.array(res.loc[final_output_inds, "se"])
valid_inds[geo_id] = np.array(res.loc[final_output_inds, "incl"])
temp_df = pd.DataFrame(res)
temp_df = temp_df.loc[final_output_inds]
df_lst.append(pd.DataFrame(temp_df))
output_df = pd.concat(df_lst)
else:
n_cpu = min(Config.MAX_CPU_POOL, cpu_count())
logging.debug("starting pool with %d workers", n_cpu)
@@ -188,83 +182,57 @@ def update_indicator(self, input_filepath, outpath, logger):
for geo_id, sub_data in data_frame.groupby(level=0, as_index=False):
sub_data.reset_index(inplace=True)
if self.weekday:
sub_data = Weekday.calc_adjustment(
wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data.set_index(Config.DATE_COL, inplace=True)
pool_results.append(
pool.apply_async(
ClaimsHospIndicator.fit,
args=(sub_data, self.burnindate, geo_id,),
args=(
sub_data,
self.burnindate,
geo_id,
),
)
)
pool_results = [proc.get() for proc in pool_results]
for res in pool_results:
geo_id = res["geo_id"]
res = pd.DataFrame(res)
rates[geo_id] = np.array(res.loc[final_output_inds, "rate"])
std_errs[geo_id] = np.array(res.loc[final_output_inds, "se"])
valid_inds[geo_id] = np.array(res.loc[final_output_inds, "incl"])

# write out results
unique_geo_ids = list(rates.keys())
output_dict = {
"rates": rates,
"se": std_errs,
"dates": self.output_dates,
"geo_ids": unique_geo_ids,
"geo_level": self.geo,
"include": valid_inds,
}

self.write_to_csv(output_dict, outpath)
logging.debug("wrote files to %s", outpath)

def write_to_csv(self, output_dict, output_path="./receiving"):
df_lst = [pd.DataFrame(proc.get()).loc[final_output_inds] for proc in pool_results]
output_df = pd.concat(df_lst)

return output_df

def preprocess_output(self, df) -> pd.DataFrame:
"""
Write values to csv.
Check for any anomalies and format the output for export.

Args:
output_dict: dictionary containing values, se, unique dates, and unique geo_id
output_path: outfile path to write the csv
Parameters
----------
df

Returns
-------
df
"""
filtered_df = df[df["incl"]]
filtered_df = filtered_df.reset_index()
filtered_df.rename(columns={"rate": "val"}, inplace=True)
filtered_df["timestamp"] = filtered_df["timestamp"].astype(str)
df_list = []
if self.write_se:
logging.info("========= WARNING: WRITING SEs TO %s =========",
self.signal_name)

geo_level = output_dict["geo_level"]
dates = output_dict["dates"]
geo_ids = output_dict["geo_ids"]
all_rates = output_dict["rates"]
all_se = output_dict["se"]
all_include = output_dict["include"]
out_n = 0
for i, date in enumerate(dates):
filename = "%s/%s_%s_%s.csv" % (
output_path,
(date + Config.DAY_SHIFT).strftime("%Y%m%d"),
geo_level,
self.signal_name,
)
with open(filename, "w") as outfile:
outfile.write("geo_id,val,se,direction,sample_size\n")
for geo_id in geo_ids:
val = all_rates[geo_id][i]
se = all_se[geo_id][i]
if all_include[geo_id][i]:
assert not np.isnan(val), "value for included value is nan"
assert not np.isnan(se), "se for included rate is nan"
if val > 90:
logging.warning("value suspicious, %s: %d", geo_id, val)
assert se < 5, f"se suspicious, {geo_id}: {se}"
if self.write_se:
assert val > 0 and se > 0, "p=0, std_err=0 invalid"
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, val, se, "NA", "NA"))
else:
# for privacy reasons we will not report the standard error
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, val, "NA", "NA", "NA"))
out_n += 1

logging.debug("wrote %d rows for %d %s", out_n, len(geo_ids), geo_level)
logging.info("WARNING: WRITING SEs")
for geo_id, group in filtered_df.groupby("geo_id"):
assert not group.val.isnull().any()
assert not group.se.isnull().any()
assert np.all(group.se < 5), f"se suspicious, {geo_id}: {np.where(group.se >= 5)[0]}"
if np.any(group.val > 90):
for sus_val in np.where(group.val > 90):
logging.warning("value suspicious, %s: %d", geo_id, sus_val)
if self.write_se:
assert np.all(group.val > 0) and np.all(group.se > 0), "p=0, std_err=0 invalid"
else:
group["se"] = np.NaN
group["sample_size"] = np.NaN
df_list.append(group)
@dshemetov (Contributor) commented on Sep 17, 2024:

Rereading this bit, I'm thinking: df_list seems unnecessary. These are all just asserts that terminate the pipeline if any of the values don't pass validation, so we should just run the asserts but not rebuild the df. The only difference between filtered_df and output_df is the group["se"] = np.NaN and group["sample_size"] = np.NaN transformations, but those are independent of group, so they can be handled outside the for-loop. It might even make sense to handle

        filtered_df = df[df["incl"]]
        filtered_df = filtered_df.reset_index()
        filtered_df.rename(columns={"rate": "val"}, inplace=True)
        filtered_df["timestamp"] = filtered_df["timestamp"].astype(str)

and the NA assignments at the end of update_indicator, call this function validate_dataframe, and not have it return anything.

@dshemetov (Contributor) commented on Sep 17, 2024:

Initially, I'd just try removing the df_list and output_df though and do something like this after the for loop with all the assert statements.

if not self.write_se:
    filtered_df["se"] = np.NaN
filtered_df["sample_size"] = np.NaN
filtered_df.drop(columns=["incl"], inplace = True)
assert sorted(list(filtered_df.columns)) == ["geo_id", "sample_size", "se", "timestamp", "val"]
return filtered_df

@dshemetov (Contributor) commented on Sep 17, 2024:

Curious (a) if that works, (b) how much that speeds things up. I'd guess that this for loop is really expensive because it runs over all counties (update_indicator does too, but at least we parallelize that). There may be a way to avoid the for loop using DataFrame methods, but we can get there later.

The PR author replied:

Actually, create_export_csv is what's slowing down the runs. The preprocessing takes about half the time of the previous write_to_csv; it's the delphi_utils create_export_csv that's slower than the previous version.
Still moved the NaN columns outside of the loop, though.

A contributor replied:

Gotcha, I see what you mean. Really surprised create_export_csv is that much slower, that's a bit unfortunate.


output_df = pd.concat(df_list)
output_df.drop(columns=["incl"], inplace=True)
assert sorted(list(output_df.columns)) == ["geo_id", "sample_size", "se", "timestamp", "val"]
return output_df
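
Following up on the review thread above, the sketch below shows one way the per-geo_id checks in preprocess_output could be done with vectorized DataFrame operations instead of the groupby loop, as dshemetov suggested might be possible. This is not code from this PR: the standalone function signature (write_se and signal_name as parameters), the messages, and the exact checks are assumptions layered on the column names the PR uses.

```python
# Sketch only (assumption, not part of this PR): vectorized validation that
# avoids rebuilding the frame group by group.
import logging

import numpy as np
import pandas as pd


def preprocess_output_vectorized(df: pd.DataFrame, write_se: bool, signal_name: str) -> pd.DataFrame:
    """Validate and format indicator output for create_export_csv."""
    # Mirrors the PR: keep included rows, rename rate -> val, stringify dates.
    # Assumes the date index is named "timestamp", as in update_indicator above.
    out = df[df["incl"]].reset_index()
    out = out.rename(columns={"rate": "val"})
    out["timestamp"] = out["timestamp"].astype(str)

    # Vectorized sanity checks; any failure stops the pipeline, like the
    # per-group asserts in the loop above.
    assert not out["val"].isnull().any(), "value for included row is nan"
    assert not out["se"].isnull().any(), "se for included row is nan"
    assert (out["se"] < 5).all(), (
        f"se suspicious for geo_ids: {out.loc[out['se'] >= 5, 'geo_id'].tolist()}"
    )
    for _, row in out[out["val"] > 90].iterrows():
        logging.warning("value suspicious, %s: %f", row["geo_id"], row["val"])

    if write_se:
        logging.info("WARNING: WRITING SEs TO %s", signal_name)
        assert ((out["val"] > 0) & (out["se"] > 0)).all(), "p=0, std_err=0 invalid"
    else:
        # For privacy reasons the standard error is not reported.
        out["se"] = np.nan
    out["sample_size"] = np.nan

    out = out.drop(columns=["incl"])
    assert sorted(out.columns) == ["geo_id", "sample_size", "se", "timestamp", "val"]
    return out
```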