Skip to content

Commit

Permalink
update for correct data source
Browse files Browse the repository at this point in the history
  • Loading branch information
damonbayer committed Sep 16, 2024
1 parent 672a800 commit 7fa5c1b
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 65 deletions.
20 changes: 12 additions & 8 deletions nssp_demo/fit_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@
autoreg_p_hosp_rv,
autoreg_rt_rv,
eta_sd_rv,
generation_interval_pmf_rv,
hosp_wday_effect_rv,
i0_first_obs_n_rv,
inf_feedback_strength_rv,
infection_feedback_pmf_rv,
initialization_rate_rv,
log_r_mu_intercept_rv,
p_hosp_mean_rv,
Expand All @@ -30,15 +28,13 @@
import pyrenew_covid_wastewater.plotting as plotting
from pyrenew_covid_wastewater.hosp_only_ww_model import hosp_only_ww_model

n_chains = 4
numpyro.set_host_device_count(n_chains)

# read this from cli
model_dir = Path(
"private_data/r_2024-09-10_f_2024-03-13_l_2024-09-09_t_2024-08-14/CA"
)

n_chains = 4
numpyro.set_host_device_count(n_chains)


data_path = model_dir / "data_for_model_fit.json"

with open(
Expand All @@ -50,7 +46,15 @@

inf_to_hosp_rv = DeterministicVariable(
"inf_to_hosp", jnp.array(model_data["inf_to_hosp_pmf"])
)
) # check if off by 1 or reversed

generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf", jnp.array(model_data["generation_interval_pmf"])
) # check if off by 1 or reversed

infection_feedback_pmf_rv = DeterministicVariable(
"infection_feedback_pmf", jnp.array(model_data["generation_interval_pmf"])
) # check if off by 1 or reversed

data_observed_hospital_admissions = jnp.array(
model_data["data_observed_hospital_admissions"]
Expand Down
48 changes: 42 additions & 6 deletions nssp_demo/prep_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ library(tidyverse)
library(CFAEpiNow2Pipeline)
library(usa)
library(fs)
library(arrow)
library(here)


Expand Down Expand Up @@ -40,17 +41,52 @@ prep_data <- function(report_date = today(),
filter(abb == state_abb) %>%
pull(population)

inf_to_hosp_pmf <-
read_csv(here(path("nssp_demo", "private_data", "latest", ext = "csv"))) %>%
filter(geo_value == str_to_lower(state_abb)) %>%
arrange(delay) %>%
pull(p)
nnh_estimates <- read_parquet(
here(path("nssp_demo",
"private_data",
"prod",
ext = "parquet"
))
)

generation_interval_pmf <-
nnh_estimates %>%
filter(
is.na(geo_value),
disease == "COVID-19",
parameter == "generation_interval"
) %>%
pull(value) %>%
pluck(1)


delay_pmf <-
nnh_estimates %>%
filter(
is.na(geo_value),
disease == "COVID-19",
parameter == "delay"
) %>%
pull(value) %>%
pluck(1)

right_truncation_pmf <-
nnh_estimates %>%
filter(
geo_value == state_abb,
disease == "COVID-19",
parameter == "right_truncation"
) %>%
pull(value) %>%
pluck(1)


list(
prepped_date = prepped_data,
data_for_model_fit = list(
inf_to_hosp_pmf = inf_to_hosp_pmf,
inf_to_hosp_pmf = delay_pmf,
generation_interval_pmf = generation_interval_pmf,
right_truncation_pmf = right_truncation_pmf,
data_observed_hospital_admissions = train_ed_admissions,
test_ed_admissions = test_ed_admissions,
state_pop = state_pop
Expand Down
51 changes: 0 additions & 51 deletions nssp_demo/priors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import jax.numpy as jnp
import numpyro.distributions as dist
import pyrenew.transformation as transformation
from pyrenew.deterministic import DeterministicVariable
from pyrenew.randomvariable import DistributionalVariable, TransformedVariable

from pyrenew_covid_wastewater.utils import convert_to_logmean_log_sd
Expand Down Expand Up @@ -36,53 +35,6 @@
autoreg_rt_rv = DistributionalVariable("autoreg_rt", dist.Beta(2, 40))


generation_interval_pmf_rv = DeterministicVariable(
"generation_interval_pmf",
jnp.array(
[
0.161701189933765,
0.320525743089203,
0.242198071982593,
0.134825252524032,
0.0689141939998525,
0.0346219683116734,
0.017497710736154,
0.00908172017279556,
0.00483656086299504,
0.00260732346885217,
0.00143298046642562,
0.00082002579123121,
0.0004729600977183,
0.000284420637980485,
0.000179877924728358,
]
),
)


infection_feedback_pmf_rv = DeterministicVariable(
"infection_feedback_pmf",
jnp.array(
[
0.161701189933765,
0.320525743089203,
0.242198071982593,
0.134825252524032,
0.0689141939998525,
0.0346219683116734,
0.017497710736154,
0.00908172017279556,
0.00483656086299504,
0.00260732346885217,
0.00143298046642562,
0.00082002579123121,
0.0004729600977183,
0.000284420637980485,
0.000179877924728358,
]
),
)

inf_feedback_strength_rv = TransformedVariable(
"inf_feedback",
DistributionalVariable(
Expand Down Expand Up @@ -135,6 +87,3 @@


uot = 55

# state_pop
# depends on state

0 comments on commit 7fa5c1b

Please sign in to comment.