refactor hospital admission to use delphi_utils create_export_csv #2032

Closed
Wants to merge 25 commits
Changes from 20 commits
9 changes: 8 additions & 1 deletion _delphi_utils_python/delphi_utils/export.py
@@ -116,7 +116,7 @@ def create_export_csv(
"sample_size",
"missing_val",
"missing_se",
"missing_sample_size"
"missing_sample_size",
]
export_df = df[df["timestamp"] == date].filter(items=expected_columns)
if "missing_val" in export_df.columns:
@@ -129,4 +129,11 @@ def create_export_csv(
if sort_geos:
export_df = export_df.sort_values(by="geo_id")
export_df.to_csv(export_file, index=False, na_rep="NA")

logger.debug(
"Wrote rows",
num_rows=df.size,
geo_type=geo_res,
num_geo_ids=export_df["geo_id"].unique().size,
)
return dates
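For context, a minimal sketch of how an indicator might call the updated helper so the new debug line fires. The DataFrame contents and export directory here are illustrative only, and get_structured_logger is assumed to be the delphi_utils structured logger used elsewhere in this PR (a plain logging.Logger may not accept the keyword fields).

import pandas as pd
from delphi_utils import create_export_csv, get_structured_logger

logger = get_structured_logger(__name__)  # assumed: structlog-style logger accepting keyword args
df = pd.DataFrame({
    "geo_id": ["ca", "ny"],
    "timestamp": pd.to_datetime(["2020-02-15", "2020-02-15"]),
    "val": [1.5, 2.5],
    "se": [0.1, 0.2],
    "sample_size": [100, 200],
})
create_export_csv(
    df,
    export_dir="./receiving",  # illustrative path
    geo_res="state",
    sensor="test",
    logger=logger,  # enables the "Wrote rows" debug message added above
)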
23 changes: 18 additions & 5 deletions _delphi_utils_python/tests/test_export.py
@@ -1,4 +1,5 @@
"""Tests for exporting CSV files."""
import logging
from datetime import datetime
from os import listdir
from os.path import join
@@ -11,6 +12,7 @@

from delphi_utils import create_export_csv, Nans

TEST_LOGGER = logging.getLogger()

def _set_df_dtypes(df: pd.DataFrame, dtypes: Dict[str, Any]) -> pd.DataFrame:
assert all(isinstance(e, type) or isinstance(e, str) for e in dtypes.values()), (
@@ -102,6 +104,7 @@ def test_export_with_metric(self, tmp_path):
metric="deaths",
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -122,6 +125,7 @@ def test_export_rounding(self, tmp_path):
metric="deaths",
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)
assert_frame_equal(
pd.read_csv(join(tmp_path, "20200215_county_deaths_test.csv")),
@@ -144,6 +148,7 @@ def test_export_without_metric(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -163,6 +168,7 @@ def test_export_with_limiting_start_date(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -182,6 +188,7 @@ def test_export_with_limiting_end_date(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -199,6 +206,7 @@ def test_export_with_no_dates(self, tmp_path):
export_dir=tmp_path,
geo_res="state",
sensor="test",
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -228,7 +236,8 @@ def test_export_with_null_removal(self, tmp_path):
export_dir=tmp_path,
geo_res="state",
sensor="test",
remove_null_samples=True
remove_null_samples=True,
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -259,7 +268,8 @@ def test_export_without_null_removal(self, tmp_path):
export_dir=tmp_path,
geo_res="state",
sensor="test",
remove_null_samples=False
remove_null_samples=False,
logger=TEST_LOGGER
)

assert set(listdir(tmp_path)) == set(
@@ -275,7 +285,7 @@ def test_export_without_null_removal(self, tmp_path):
def test_export_df_without_missingness(self, tmp_path):

create_export_csv(
df=self.DF.copy(), export_dir=tmp_path, geo_res="county", sensor="test"
df=self.DF.copy(), export_dir=tmp_path, geo_res="county", sensor="test", logger=TEST_LOGGER
)
df = pd.read_csv(join(tmp_path, "20200215_county_test.csv")).astype(
{"geo_id": str, "sample_size": int}
@@ -297,6 +307,7 @@ def test_export_df_with_missingness(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
logger=TEST_LOGGER
)
assert set(listdir(tmp_path)) == set(
[
@@ -358,7 +369,8 @@ def test_export_sort(self, tmp_path):
unsorted_df,
export_dir=tmp_path,
geo_res="county",
sensor="test"
sensor="test",
logger=TEST_LOGGER
)
expected_df = pd.DataFrame({
"geo_id": ["51175", "51093"],
@@ -374,7 +386,8 @@ def test_export_sort(self, tmp_path):
export_dir=tmp_path,
geo_res="county",
sensor="test",
sort_geos=True
sort_geos=True,
logger=TEST_LOGGER
)
expected_df = pd.DataFrame({
"geo_id": ["51093", "51175"],
3 changes: 2 additions & 1 deletion changehc/delphi_changehc/update_sensor.py
@@ -60,7 +60,8 @@ def write_to_csv(df, geo_level, write_se, day_shift, out_name, logger, output_pa
start_date=start_date,
end_date=end_date,
sensor=out_name,
write_empty_days=True
write_empty_days=True,
logger=logger,
)
logger.debug("wrote {0} rows for {1} {2}".format(
df.size, df["geo_id"].unique().size, geo_level
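The changehc caller simply forwards its existing logger. With the structured logger assumed above, the keyword-argument style added in export.py would replace the legacy format-string call kept in this hunk, roughly as sketched below (field names mirror the logger.debug call added in export.py; this is illustrative, not part of the diff).

# Structured-logging equivalent of the legacy "wrote {0} rows ..." call above.
logger.debug(
    "Wrote rows",
    num_rows=df.size,
    geo_type=geo_level,
    num_geo_ids=df["geo_id"].unique().size,
)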
23 changes: 15 additions & 8 deletions claims_hosp/delphi_claims_hosp/run.py
@@ -5,22 +5,20 @@
when the module is run with `python -m delphi_claims_hosp`.
"""

# standard packages
import time
import os
import time
from datetime import datetime, timedelta
from pathlib import Path

# third party
from delphi_utils import get_structured_logger
from delphi_utils.export import create_export_csv

# first party
from .backfill import merge_backfill_file, store_backfill_file
from .config import Config
from .download_claims_ftp_files import download
from .modify_claims_drops import modify_and_write
from .get_latest_claims_name import get_latest_filename
from .modify_claims_drops import modify_and_write
from .update_indicator import ClaimsHospIndicatorUpdater
from .backfill import (store_backfill_file, merge_backfill_file)


def run_module(params):
@@ -137,11 +135,20 @@ def run_module(params):
params["indicator"]["write_se"],
signal_name
)
updater.update_indicator(
output = updater.update_indicator(
claims_file,
params["common"]["export_dir"],
logger,
)
filtered_output_df = updater.preprocess_output(output)
create_export_csv(
filtered_output_df,
export_dir=params["common"]["export_dir"],
start_date=startdate,
geo_res=geo,
sensor=signal_name,
logger=logger,
)

max_dates.append(updater.output_dates[-1])
n_csv_export.append(len(updater.output_dates))
logger.info("finished updating", geo = geo)
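Taken together, the per-signal flow in run.py after this change reduces to the sketch below. Names follow the diff; the surrounding loop over geos and signals, updater construction, and error handling are elided, and update_indicator is shown with the updated two-argument signature from update_indicator.py.

# Condensed per-signal flow after the refactor (sketch only, not a drop-in replacement).
output = updater.update_indicator(claims_file, logger)      # now returns a DataFrame instead of writing CSVs
filtered_output_df = updater.preprocess_output(output)      # drop rows with incl == False, rename rate -> val
create_export_csv(
    filtered_output_df,
    export_dir=params["common"]["export_dir"],
    start_date=startdate,
    geo_res=geo,
    sensor=signal_name,
    logger=logger,
)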
135 changes: 51 additions & 84 deletions claims_hosp/delphi_claims_hosp/update_indicator.py
@@ -133,18 +133,15 @@ def geo_reindex(self, data):
data_frame.fillna(0, inplace=True)
return data_frame

def update_indicator(self, input_filepath, outpath, logger):
def update_indicator(self, input_filepath, logger):
"""
Generate and output indicator values.

Args:
input_filepath: path to the aggregated claims data
outpath: output path for the csv results

"""
self.shift_dates()
final_output_inds = \
(self.burn_in_dates >= self.startdate) & (self.burn_in_dates <= self.enddate)
final_output_inds = (self.burn_in_dates >= self.startdate) & (self.burn_in_dates <= self.enddate)

# load data
base_geo = Config.HRR_COL if self.geo == Config.HRR_COL else Config.FIPS_COL
@@ -164,22 +161,18 @@ def update_indicator(self, input_filepath, outpath, logger):
if self.weekday
else None
)
# run fitting code (maybe in parallel)
rates = {}
std_errs = {}
valid_inds = {}
df_lst = []
if not self.parallel:
for geo_id, sub_data in data_frame.groupby(level=0):
sub_data.reset_index(inplace=True)
if self.weekday:
sub_data = Weekday.calc_adjustment(
wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data.set_index(Config.DATE_COL, inplace=True)
res = ClaimsHospIndicator.fit(sub_data, self.burnindate, geo_id)
res = pd.DataFrame(res)
rates[geo_id] = np.array(res.loc[final_output_inds, "rate"])
std_errs[geo_id] = np.array(res.loc[final_output_inds, "se"])
valid_inds[geo_id] = np.array(res.loc[final_output_inds, "incl"])
temp_df = pd.DataFrame(res)
temp_df = temp_df.loc[final_output_inds]
df_lst.append(pd.DataFrame(temp_df))
output_df = pd.concat(df_lst)
else:
n_cpu = min(Config.MAX_CPU_POOL, cpu_count())
logging.debug("starting pool with %d workers", n_cpu)
@@ -188,83 +181,57 @@ def update_indicator(self, input_filepath, outpath, logger):
for geo_id, sub_data in data_frame.groupby(level=0, as_index=False):
sub_data.reset_index(inplace=True)
if self.weekday:
sub_data = Weekday.calc_adjustment(
wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data = Weekday.calc_adjustment(wd_params, sub_data, ["num"], Config.DATE_COL)
sub_data.set_index(Config.DATE_COL, inplace=True)
pool_results.append(
pool.apply_async(
ClaimsHospIndicator.fit,
args=(sub_data, self.burnindate, geo_id,),
args=(
sub_data,
self.burnindate,
geo_id,
),
)
)
pool_results = [proc.get() for proc in pool_results]
for res in pool_results:
geo_id = res["geo_id"]
res = pd.DataFrame(res)
rates[geo_id] = np.array(res.loc[final_output_inds, "rate"])
std_errs[geo_id] = np.array(res.loc[final_output_inds, "se"])
valid_inds[geo_id] = np.array(res.loc[final_output_inds, "incl"])

# write out results
unique_geo_ids = list(rates.keys())
output_dict = {
"rates": rates,
"se": std_errs,
"dates": self.output_dates,
"geo_ids": unique_geo_ids,
"geo_level": self.geo,
"include": valid_inds,
}

self.write_to_csv(output_dict, outpath)
logging.debug("wrote files to %s", outpath)

def write_to_csv(self, output_dict, output_path="./receiving"):
df_lst = [pd.DataFrame(proc.get()).loc[final_output_inds] for proc in pool_results]
output_df = pd.concat(df_lst)

return output_df

def preprocess_output(self, df) -> pd.DataFrame:
"""
Write values to csv.
Check for anomalies and format the output for export.

Args:
output_dict: dictionary containing values, se, unique dates, and unique geo_id
output_path: outfile path to write the csv
Parameters
----------
df: pd.DataFrame
concatenated per-geo fit results with rate, se, and incl columns

Returns
-------
pd.DataFrame
formatted frame with geo_id, timestamp, val, se, and sample_size columns, ready for create_export_csv
"""
filtered_df = df[df["incl"]]
filtered_df = filtered_df.reset_index()
filtered_df.rename(columns={"rate": "val"}, inplace=True)
filtered_df["timestamp"] = filtered_df["timestamp"].astype(str)
df_list = []
if self.write_se:
logging.info("========= WARNING: WRITING SEs TO %s =========",
self.signal_name)

geo_level = output_dict["geo_level"]
dates = output_dict["dates"]
geo_ids = output_dict["geo_ids"]
all_rates = output_dict["rates"]
all_se = output_dict["se"]
all_include = output_dict["include"]
out_n = 0
for i, date in enumerate(dates):
filename = "%s/%s_%s_%s.csv" % (
output_path,
(date + Config.DAY_SHIFT).strftime("%Y%m%d"),
geo_level,
self.signal_name,
)
with open(filename, "w") as outfile:
outfile.write("geo_id,val,se,direction,sample_size\n")
for geo_id in geo_ids:
val = all_rates[geo_id][i]
se = all_se[geo_id][i]
if all_include[geo_id][i]:
assert not np.isnan(val), "value for included value is nan"
assert not np.isnan(se), "se for included rate is nan"
if val > 90:
logging.warning("value suspicious, %s: %d", geo_id, val)
assert se < 5, f"se suspicious, {geo_id}: {se}"
if self.write_se:
assert val > 0 and se > 0, "p=0, std_err=0 invalid"
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, val, se, "NA", "NA"))
else:
# for privacy reasons we will not report the standard error
outfile.write(
"%s,%f,%s,%s,%s\n" % (geo_id, val, "NA", "NA", "NA"))
out_n += 1

logging.debug("wrote %d rows for %d %s", out_n, len(geo_ids), geo_level)
logging.info("WARNING: WRITING SEs")
for geo_id, group in filtered_df.groupby("geo_id"):
assert not group.val.isnull().any()
assert not group.se.isnull().any()
assert np.all(group.se < 5), f"se suspicious, {geo_id}: {np.where(group.se >= 5)[0]}"
if np.any(group.val > 90):
for sus_val in group.val[group.val > 90]:
logging.warning("value suspicious, %s: %d", geo_id, sus_val)
df_list.append(group)
if self.write_se:
assert np.all(group.val > 0) and np.all(group.se > 0), "p=0, std_err=0 invalid"

output_df = pd.concat(df_list)
output_df.drop(columns=["incl"], inplace=True)
if not self.write_se:
output_df["se"] = np.NaN
output_df["sample_size"] = np.NaN
assert sorted(list(output_df.columns)) == ["geo_id", "sample_size", "se", "timestamp", "val"]
return output_df
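A rough usage sketch of preprocess_output, assuming an already constructed updater (the ClaimsHospIndicatorUpdater arguments are elided) and an input frame shaped like the concatenated fit output: a date index plus geo_id, rate, se, and incl columns. The index name and all concrete values are illustrative assumptions.

import pandas as pd

# Illustrative fit-style input; the index name is assumed to match the
# timestamp column preprocess_output expects after reset_index().
raw = pd.DataFrame(
    {
        "geo_id": ["01001", "01001", "01003"],
        "rate": [1.2, 1.4, 2.0],
        "se": [0.3, 0.2, 0.4],
        "incl": [True, True, False],  # the incl == False row is dropped
    },
    index=pd.DatetimeIndex(["2020-02-15", "2020-02-16", "2020-02-15"], name="timestamp"),
)

out = updater.preprocess_output(raw)  # `updater` assumed built as in run.py
# out has columns geo_id, timestamp, val, se, sample_size; with write_se off,
# se and sample_size are NaN so standard errors are not published.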