Skip to content

Commit

Permalink
Merge pull request #130 from CDCgov/127-nicer-constructor-for-varianc…
Browse files Browse the repository at this point in the history
…e-priors

Create `_make_halfnormal_prior` function
  • Loading branch information
seabbs authored Mar 11, 2024
2 parents 178e3ae + e36b550 commit 649567f
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 17 deletions.
18 changes: 11 additions & 7 deletions EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ In `EpiAware` we provide a constructor for random walk latent models with priors
```math
\begin{align}
Z_0 &\sim \mathcal{N}(0,1),\\
\sigma^2_Z &\sim \text{HalfNormal}(0.01).
\sigma_{RW} &\sim \text{HalfNormal}(0.1 * \sqrt{\pi} / \sqrt{2})).
\end{align}
```
"

# ╔═╡ 56ae496b-0094-460b-89cb-526627991717
rwp = EpiAware.RandomWalk(Normal(),
truncated(Normal(0.0, 0.02), 0.0, Inf))
EpiAware._make_halfnormal_prior(0.1))

# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44
md"
Expand Down Expand Up @@ -179,7 +179,7 @@ We choose a simple observation model where infections are observed 0, 1, 2, 3 da
obs_model = EpiAware.DelayObservations(
fill(0.25, 4),
time_horizon,
truncated(Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf)
EpiAware._make_halfnormal_prior(0.1)
)

# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b
Expand Down Expand Up @@ -298,12 +298,15 @@ init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1]))

# ╔═╡ 4deb3a51-781d-48c4-91f6-6adf2b1affcf
md"
**NB: We are running this inference run for speed rather than accuracy as a demonstration. Use a higher target acceptance and more samples in a typical workflow.**
**NB: We are running this inference run for speed rather than accuracy as a demonstration. You might want to use a higher target acceptance and more samples in a typical workflow.**
"

# ╔═╡ 946b1c43-e750-40c9-9f14-79da9735e437
target_acc_prob = 0.8

# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
chn = sample(inference_mdl,
NUTS(; adtype = AutoReverseDiff(true)),
NUTS(target_acc_prob; adtype = AutoReverseDiff(true)),
MCMCThreads(),
250,
4;
Expand Down Expand Up @@ -360,7 +363,7 @@ As well as checking the posterior predictions for latent infections, we can also

# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3
let
parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor)
parameters_to_plot = (:σ_RW, :neg_bin_cluster_factor)

plts = map(parameters_to_plot) do name
var_samples = chn[name] |> vec
Expand Down Expand Up @@ -446,10 +449,11 @@ end
# ╠═b4033728-b321-4100-8194-1fd9fe2d268d
# ╟─9222b436-9445-4039-abbf-25c8cddb7f63
# ╠═197a4fbb-b71a-475a-bb78-28ff613e3094
# ╠═073a1d40-456a-450e-969f-11b23eb7fd1f
# ╟─073a1d40-456a-450e-969f-11b23eb7fd1f
# ╠═0379b058-4c35-440a-bc01-aafa0178bdbf
# ╠═a7798f71-9bb5-4506-9476-0cc11553b9e2
# ╟─4deb3a51-781d-48c4-91f6-6adf2b1affcf
# ╠═946b1c43-e750-40c9-9f14-79da9735e437
# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1
# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5
Expand Down
5 changes: 2 additions & 3 deletions EpiAware/src/latentmodels/randomwalk.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
init_prior::D
var_prior::S
std_prior::S
end

function default_rw_priors()
Expand All @@ -10,9 +10,8 @@ end

@model function generate_latent(latent_model::RandomWalk, n)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_model.var_prior
σ_RW ~ latent_model.std_prior
rw_init ~ latent_model.init_prior
σ_RW = sqrt(σ²_RW)
rw = Vector{eltype(ϵ_t)}(undef, n)

rw[1] = rw_init + σ_RW * ϵ_t[1]
Expand Down
16 changes: 14 additions & 2 deletions EpiAware/src/utils/dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ function r_to_R(r, w::AbstractVector)
end

"""
NegativeBinomialMeanClust(μ, α)
Compute the mean-cluster factor negative binomial distribution.
# Arguments
Expand All @@ -167,3 +165,17 @@ function NegativeBinomialMeanClust(μ, α)
r = μ^2 / ex_σ²
return NegativeBinomial(r, p)
end

"""
Create a half-normal prior distribution with the specified mean.
# Arguments
- `prior_mean::AbstractFloat`: The mean of the prior distribution.
# Returns
- `Truncated{Normal}`: The half-normal prior distribution.
"""
function _make_halfnormal_prior(prior_mean::AbstractFloat)
return truncated(Normal(0.0, prior_mean * sqrt(pi) / sqrt(2)), 0.0, Inf)
end
8 changes: 4 additions & 4 deletions EpiAware/test/test_latent-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
rw_process = EpiAware.RandomWalk(Normal(0.0, 1.0),
truncated(Normal(0.0, 0.05), 0.0, Inf))
model = EpiAware.generate_latent(rw_process, n)
fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
fixed_model = fix(model, (σ_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
samples_day_5 = sample(fixed_model, Prior(), n_samples) |>
chn -> mapreduce(vcat, generated_quantities(fixed_model, chn)) do gen
Expand All @@ -33,8 +33,8 @@ end
end
@testset "Testing RandomWalk constructor" begin
init_prior = Normal(0.0, 1.0)
var_prior = truncated(Normal(0.0, 0.05), 0.0, Inf)
rw_process = RandomWalk(init_prior, var_prior)
std_prior = truncated(Normal(0.0, 0.05), 0.0, Inf)
rw_process = RandomWalk(init_prior, std_prior)
@test rw_process.init_prior == init_prior
@test rw_process.var_prior == var_prior
@test rw_process.std_prior == std_prior
end
2 changes: 1 addition & 1 deletion EpiAware/test/test_observation-models.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testitem "Testing delay obs against theoretical properties" begin
using DynamicPPL, Turing, Distributions
using HypothesisTests#: ExactOneSampleKSTest, pvalue
using HypothesisTests

# Set up test data with fixed infection
I_t = [10.0, 20.0, 30.0]
Expand Down
23 changes: 23 additions & 0 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,26 @@ end
@test df == expected_df
end
end

@testitem "Testing _make_halfnormal_prior" begin
using Distributions, HypothesisTests
@testset "Check distribution type" begin
prior_mean = 10.0
prior_dist = EpiAware._make_halfnormal_prior(prior_mean)
@test typeof(prior_dist) <: Distribution
end

@testset "Check distribution properties" begin
prior_mean = 2.0
prior_dist = EpiAware._make_halfnormal_prior(prior_mean)
#Check Distributions.jl mean function
@test mean(prior_dist) prior_mean
samples = rand(prior_dist, 10_000)
#Check mean from direct sampling of folded distribution and ANOVA and Variance F test comparisons
direct_samples = randn(10_000) * prior_mean * sqrt(pi) / sqrt(2) .|> abs
mean_pval = OneWayANOVATest(samples, direct_samples) |> pvalue
@test mean_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
var_pval = VarianceFTest(samples, direct_samples) |> pvalue
@test var_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
end
end

0 comments on commit 649567f

Please sign in to comment.