Skip to content

Commit

Permalink
Change pos_shift to be a property of Observation model (#147)
Browse files Browse the repository at this point in the history
* Change `pos_shift` into a property of ObservationModel

* Update tests
  • Loading branch information
SamuelBrand1 authored Mar 15, 2024
1 parent 1556a25 commit 147c09e
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 26 deletions.
2 changes: 1 addition & 1 deletion EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.19.39
# v0.19.40

using Markdown
using InteractiveUtils
Expand Down
17 changes: 10 additions & 7 deletions EpiAware/src/EpiObsModels/delayobservations.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
struct DelayObservations{T <: AbstractFloat, S <: Sampleable} <: AbstractObservationModel
delay_kernel::SparseMatrixCSC{T, Integer}
neg_bin_cluster_factor_prior::S
pos_shift::T

function DelayObservations(delay_int,
time_horizon,
neg_bin_cluster_factor_prior)
neg_bin_cluster_factor_prior;
pos_shift = 1e-6)
@assert all(delay_int .>= 0) "Delay interval must be non-negative"
@assert sum(delay_int)1 "Delay interval must sum to 1"

K = generate_observation_kernel(delay_int, time_horizon)

new{eltype(K), typeof(neg_bin_cluster_factor_prior)}(K,
neg_bin_cluster_factor_prior)
neg_bin_cluster_factor_prior, pos_shift)
end

function DelayObservations(;
delay_distribution::ContinuousDistribution,
time_horizon::Integer,
neg_bin_cluster_factor_prior::Sampleable,
D_delay,
Δd = 1.0)
Δd = 1.0,
pos_shift = 1e-6)
delay_int = create_discrete_pmf(delay_distribution; Δd = Δd, D = D_delay)
return DelayObservations(delay_int, time_horizon, neg_bin_cluster_factor_prior)
return DelayObservations(
delay_int, time_horizon, neg_bin_cluster_factor_prior; pos_shift)
end
end

Expand All @@ -32,14 +36,13 @@ end

@model function EpiAwareBase.generate_observations(observation_model::DelayObservations,
y_t,
I_t;
pos_shift)
I_t)

#Parameters
neg_bin_cluster_factor ~ observation_model.neg_bin_cluster_factor_prior

#Predictive distribution
expected_obs = observation_model.delay_kernel * I_t .+ pos_shift
expected_obs = observation_model.delay_kernel * I_t .+ observation_model.pos_shift

if ismissing(y_t)
y_t = Vector{Int}(undef, length(expected_obs))
Expand Down
7 changes: 3 additions & 4 deletions EpiAware/src/make_epi_aware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
time_steps;
epi_model::AbstractEpiModel,
latent_model::AbstractLatentModel,
observation_model::AbstractObservationModel,
pos_shift = 1e-6)
observation_model::AbstractObservationModel
)
#Latent process
@submodel Z_t, latent_model_aux = generate_latent(
latent_model,
Expand All @@ -15,8 +15,7 @@
#Predictive distribution of ascerted cases
@submodel generated_y_t, generated_y_t_aux = generate_observations(observation_model,
y_t,
I_t;
pos_shift = pos_shift)
I_t)

#Generate quantities
return (;
Expand Down
18 changes: 10 additions & 8 deletions EpiAware/test/test_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
time_horizon = time_horizon,
neg_bin_cluster_factor_prior = Gamma(5, 0.05 / 5),
D_delay = D_delay,
Δd = Δd)
Δd = Δd;
pos_shift = pos_shift)

# Create full epi model and sample from it
test_mdl = make_epi_aware(y_t, time_horizon; epi_model = epi_model,
latent_model = rwp,
observation_model = obs_model, pos_shift)
observation_model = obs_model)
gen = generated_quantities(test_mdl, rand(test_mdl))

#Check model sampled
Expand Down Expand Up @@ -57,15 +58,15 @@ end
time_horizon = 5
obs_model = DelayObservations([1.0],
time_horizon,
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0))
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0);
pos_shift)

# Create full epi model and sample from it
test_mdl = make_epi_aware(y_t,
time_horizon;
epi_model = epi_model,
latent_model = rwp,
observation_model = obs_model,
pos_shift)
observation_model = obs_model)

chn = sample(test_mdl, Prior(), 1000)
gens = generated_quantities(test_mdl, chn)
Expand Down Expand Up @@ -96,15 +97,16 @@ end
time_horizon = 5
obs_model = DelayObservations([1.0],
time_horizon,
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0))
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0);
pos_shift)

# Create full epi model and sample from it
test_mdl = make_epi_aware(y_t,
time_horizon;
epi_model = epi_model,
latent_model = rwp,
observation_model = obs_model,
pos_shift)
observation_model = obs_model
)

chn = sample(test_mdl, Prior(), 1000)
gens = generated_quantities(test_mdl, chn)
Expand Down
12 changes: 6 additions & 6 deletions EpiAware/test/test_observation-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@

# Delay kernel is just event observed on same day
delay_obs = DelayObservations([1.0], length(I_t),
obs_prior[:neg_bin_cluster_factor_prior])
obs_prior[:neg_bin_cluster_factor_prior];
pos_shift = 1e-6)

# Set up priors
neg_bin_cf = 0.05

# Call the function
mdl = generate_observations(delay_obs,
missing,
I_t;
pos_shift = 1e-6)
I_t)
fix_mdl = fix(mdl, (neg_bin_cluster_factor = neg_bin_cf,))

n_samples = 1000
Expand Down Expand Up @@ -52,18 +52,18 @@ end
# Define a common setup for your model that can be reused across different y_t scenarios
obs_prior = default_delay_obs_priors()
delay_obs = DelayObservations(
[1.0], length(I_t), obs_prior[:neg_bin_cluster_factor_prior])
[1.0], length(I_t), obs_prior[:neg_bin_cluster_factor_prior];
pos_shift = 1e-6)
neg_bin_cf = 0.05 # Set up priors
# Expected point estimate calculation setup
pos_shift = 1e-6

# Test each y_t scenario
for (scenario_name, y_t_scenario) in [("fully observed", y_t_fully_observed),
("partially observed", y_t_partially_observed),
("fully unobserved", y_t_fully_unobserved)]
@testset "$scenario_name y_t" begin
mdl = generate_observations(
delay_obs, y_t_scenario, I_t; pos_shift = pos_shift)
delay_obs, y_t_scenario, I_t)
sampled_obs = sample(mdl, Prior(), 1000) |>
chn -> generated_quantities(mdl, chn) .|>
(gen -> gen[1]) |>
Expand Down

0 comments on commit 147c09e

Please sign in to comment.