Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create _make_halfnormal_prior function and change to priors on standard deviations #130

Merged
merged 6 commits into from
Mar 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,14 +89,14 @@ In `EpiAware` we provide a constructor for random walk latent models with priors
```math
\begin{align}
Z_0 &\sim \mathcal{N}(0,1),\\
\sigma^2_Z &\sim \text{HalfNormal}(0.01).
\sigma_{RW} &\sim \text{HalfNormal}(0.1 * \sqrt{\pi} / \sqrt{2})).
\end{align}
```
"

# ╔═╡ 56ae496b-0094-460b-89cb-526627991717
rwp = EpiAware.RandomWalk(Normal(),
truncated(Normal(0.0, 0.02), 0.0, Inf))
EpiAware._make_halfnormal_prior(0.1))

# ╔═╡ 767beffd-1ef5-4e6c-9ac6-edb52e60fb44
md"
Expand Down Expand Up @@ -179,7 +179,7 @@ We choose a simple observation model where infections are observed 0, 1, 2, 3 da
obs_model = EpiAware.DelayObservations(
fill(0.25, 4),
time_horizon,
truncated(Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf)
EpiAware._make_halfnormal_prior(0.1)
)

# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b
Expand Down Expand Up @@ -298,12 +298,15 @@ init_params = collect.(eachrow(best_pf.draws_transformed.value[1:4, :, 1]))

# ╔═╡ 4deb3a51-781d-48c4-91f6-6adf2b1affcf
md"
**NB: We are running this inference run for speed rather than accuracy as a demonstration. Use a higher target acceptance and more samples in a typical workflow.**
**NB: We are running this inference run for speed rather than accuracy as a demonstration. You might want to use a higher target acceptance and more samples in a typical workflow.**
"

# ╔═╡ 946b1c43-e750-40c9-9f14-79da9735e437
target_acc_prob = 0.8

# ╔═╡ 3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
chn = sample(inference_mdl,
NUTS(; adtype = AutoReverseDiff(true)),
NUTS(target_acc_prob; adtype = AutoReverseDiff(true)),
MCMCThreads(),
250,
4;
Expand Down Expand Up @@ -360,7 +363,7 @@ As well as checking the posterior predictions for latent infections, we can also

# ╔═╡ 10d8fe24-83a6-47ac-97b7-a374481473d3
let
parameters_to_plot = (:σ²_RW, :neg_bin_cluster_factor)
parameters_to_plot = (:σ_RW, :neg_bin_cluster_factor)

plts = map(parameters_to_plot) do name
var_samples = chn[name] |> vec
Expand Down Expand Up @@ -446,10 +449,11 @@ end
# ╠═b4033728-b321-4100-8194-1fd9fe2d268d
# ╟─9222b436-9445-4039-abbf-25c8cddb7f63
# ╠═197a4fbb-b71a-475a-bb78-28ff613e3094
# ╠═073a1d40-456a-450e-969f-11b23eb7fd1f
# ╟─073a1d40-456a-450e-969f-11b23eb7fd1f
# ╠═0379b058-4c35-440a-bc01-aafa0178bdbf
# ╠═a7798f71-9bb5-4506-9476-0cc11553b9e2
# ╟─4deb3a51-781d-48c4-91f6-6adf2b1affcf
# ╠═946b1c43-e750-40c9-9f14-79da9735e437
# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1
# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5
Expand Down
5 changes: 2 additions & 3 deletions EpiAware/src/latentmodels/randomwalk.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
struct RandomWalk{D <: Sampleable, S <: Sampleable} <: AbstractLatentModel
init_prior::D
var_prior::S
std_prior::S
end

function default_rw_priors()
Expand All @@ -10,9 +10,8 @@ end

@model function generate_latent(latent_model::RandomWalk, n)
ϵ_t ~ MvNormal(ones(n))
σ²_RW ~ latent_model.var_prior
σ_RW ~ latent_model.std_prior
rw_init ~ latent_model.init_prior
σ_RW = sqrt(σ²_RW)
rw = Vector{eltype(ϵ_t)}(undef, n)

rw[1] = rw_init + σ_RW * ϵ_t[1]
Expand Down
16 changes: 14 additions & 2 deletions EpiAware/src/utils/dist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ function r_to_R(r, w::AbstractVector)
end

"""
NegativeBinomialMeanClust(μ, α)
seabbs marked this conversation as resolved.
Show resolved Hide resolved

Compute the mean-cluster factor negative binomial distribution.

# Arguments
Expand All @@ -167,3 +165,17 @@ function NegativeBinomialMeanClust(μ, α)
r = μ^2 / ex_σ²
return NegativeBinomial(r, p)
end

"""
Create a half-normal prior distribution with the specified mean.

# Arguments
- `prior_mean::AbstractFloat`: The mean of the prior distribution.

# Returns
- `Truncated{Normal}`: The half-normal prior distribution.

"""
function _make_halfnormal_prior(prior_mean::AbstractFloat)
return truncated(Normal(0.0, prior_mean * sqrt(pi) / sqrt(2)), 0.0, Inf)
end
8 changes: 4 additions & 4 deletions EpiAware/test/test_latent-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
rw_process = EpiAware.RandomWalk(Normal(0.0, 1.0),
truncated(Normal(0.0, 0.05), 0.0, Inf))
model = EpiAware.generate_latent(rw_process, n)
fixed_model = fix(model, (σ²_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
fixed_model = fix(model, (σ_RW = 1.0, init_rw_value = 0.0)) #Fixing the standard deviation of the random walk process
n_samples = 1000
samples_day_5 = sample(fixed_model, Prior(), n_samples) |>
chn -> mapreduce(vcat, generated_quantities(fixed_model, chn)) do gen
Expand All @@ -33,8 +33,8 @@ end
end
@testset "Testing RandomWalk constructor" begin
init_prior = Normal(0.0, 1.0)
var_prior = truncated(Normal(0.0, 0.05), 0.0, Inf)
rw_process = RandomWalk(init_prior, var_prior)
std_prior = truncated(Normal(0.0, 0.05), 0.0, Inf)
rw_process = RandomWalk(init_prior, std_prior)
@test rw_process.init_prior == init_prior
@test rw_process.var_prior == var_prior
@test rw_process.std_prior == std_prior
end
2 changes: 1 addition & 1 deletion EpiAware/test/test_observation-models.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
@testitem "Testing delay obs against theoretical properties" begin
using DynamicPPL, Turing, Distributions
using HypothesisTests#: ExactOneSampleKSTest, pvalue
using HypothesisTests

# Set up test data with fixed infection
I_t = [10.0, 20.0, 30.0]
Expand Down
23 changes: 23 additions & 0 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,26 @@ end
@test df == expected_df
end
end

@testitem "Testing _make_halfnormal_prior" begin
using Distributions, HypothesisTests
@testset "Check distribution type" begin
prior_mean = 10.0
prior_dist = EpiAware._make_halfnormal_prior(prior_mean)
@test typeof(prior_dist) <: Distribution
end

@testset "Check distribution properties" begin
prior_mean = 2.0
prior_dist = EpiAware._make_halfnormal_prior(prior_mean)
#Check Distributions.jl mean function
@test mean(prior_dist) ≈ prior_mean
samples = rand(prior_dist, 10_000)
#Check mean from direct sampling of folded distribution and ANOVA and Variance F test comparisons
direct_samples = randn(10_000) * prior_mean * sqrt(pi) / sqrt(2) .|> abs
mean_pval = OneWayANOVATest(samples, direct_samples) |> pvalue
@test mean_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
var_pval = VarianceFTest(samples, direct_samples) |> pvalue
@test var_pval > 1e-6 #Very unlikely to fail if the model is correctly implemented
end
end
Loading