Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changing the prior for negative binomial observations #123

Merged
merged 7 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 5 additions & 8 deletions EpiAware/docs/src/examples/getting_started.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
### A Pluto.jl notebook ###
# v0.19.39
# v0.19.40

using Markdown
using InteractiveUtils
Expand Down Expand Up @@ -149,7 +149,7 @@ as the action of a sparse kernel on the infections $I(t)$.
```math
\begin{align}
y_t &\sim \text{NegBinomial}(\mu = \sum_{s\geq 0} K[t, t-s] I(s), r), \\
1 / r &\sim \text{Gamma}(3, 0.05/3).
1 / \sqrt{r} &\sim \text{HalfNormal}\Big(0.1 \sqrt{{\pi \over 2}}\Big).
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved
\end{align}
```
"
Expand All @@ -171,7 +171,7 @@ We choose a simple observation model where infections are observed 0, 1, 2, 3 da
obs_model = EpiAware.DelayObservations(
fill(0.25, 4),
time_horizon,
truncated(Gamma(5, 0.05 / 5), 1e-3, 1.0)
truncated(Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf)
SamuelBrand1 marked this conversation as resolved.
Show resolved Hide resolved
)

# ╔═╡ e49713e8-4840-4083-8e3f-fc52d791be7b
Expand Down Expand Up @@ -314,8 +314,6 @@ let
size = (700, 400))
end

# ╔═╡ 2293b711-0bd0-44d5-8a30-94e56c5e4c65

# ╔═╡ fd6321b1-4c3a-4123-b0dc-c45b951e0b80
md"
As well as checking the posterior predictions for latent infections, we can also check how well inference recovered unknown parameters, such as the random walk variance or the cluster factor of the negative binomial observations.
Expand Down Expand Up @@ -386,7 +384,7 @@ end
# ╠═6639e66f-7725-4976-81b2-6472419d1a62
# ╟─df5e59f8-3185-4bed-9cca-7c266df17cec
# ╠═6fbdd8e6-2323-4352-9185-1f31a9cf9012
# ╟─5e62a50a-71f4-4902-b1c9-fdf51fe145fa
# ╠═5e62a50a-71f4-4902-b1c9-fdf51fe145fa
# ╟─e813d547-6100-4c43-b84c-8cebe306bda8
# ╠═c7580ae6-0db5-448e-8b20-4dd6fcdb1ae0
# ╟─0aa3fcbd-0831-45b8-9a2c-7ffbabf5895f
Expand All @@ -401,7 +399,7 @@ end
# ╠═d073e63b-62da-4743-ace0-78ef7806bc0b
# ╠═a04f3c1b-7e11-4800-9c2a-9fc0021de6e7
# ╟─f68b4e41-ac5c-42cd-a8c2-8761d66f7543
# ╠═b5bc8f05-b538-4abf-aa84-450bf2dff3d9
# ╟─b5bc8f05-b538-4abf-aa84-450bf2dff3d9
# ╠═c8ce0d46-a160-4c40-a055-69b3d10d1770
# ╟─4a4c6e91-8d8f-4bbf-bb7e-a36dc281e312
# ╠═259a7042-e74f-43c7-aeb4-97a3beeb7776
Expand All @@ -410,7 +408,6 @@ end
# ╠═3eb5ec5e-aae7-478e-84fb-80f2e9f85b4c
# ╟─30498cc7-16a5-441a-b8cd-c19b220c60c1
# ╠═e9df22b8-8e4d-4ab7-91ea-c01f2239b3e5
# ╠═2293b711-0bd0-44d5-8a30-94e56c5e4c65
# ╟─fd6321b1-4c3a-4123-b0dc-c45b951e0b80
# ╠═10d8fe24-83a6-47ac-97b7-a374481473d3
# ╟─81efe8ca-b753-4a12-bafc-a887a999377b
Expand Down
5 changes: 3 additions & 2 deletions EpiAware/src/observation-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ struct DelayObservations{T <: AbstractFloat, S <: Sampleable} <: AbstractObserva
end

function default_delay_obs_priors()
return (:neg_bin_cluster_factor_prior => Gamma(3, 0.05 / 3),) |> Dict
return (:neg_bin_cluster_factor_prior => truncated(
Normal(0, 0.1 * sqrt(pi) / sqrt(2)), 0.0, Inf),) |> Dict
end

function generate_observations(observation_model::AbstractObservationModel,
Expand Down Expand Up @@ -54,7 +55,7 @@ end

for i in eachindex(y_t)
y_t[i] ~ NegativeBinomialMeanClust(
expected_obs[i], neg_bin_cluster_factor
expected_obs[i], neg_bin_cluster_factor^2
)
end

Expand Down
2 changes: 1 addition & 1 deletion EpiAware/test/test_observation-models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
chn -> generated_quantities(fix_mdl, chn) .|>
(gen -> gen[1][1]) |>
vec
direct_samples = EpiAware.NegativeBinomialMeanClust(I_t[1], neg_bin_cf) |>
direct_samples = EpiAware.NegativeBinomialMeanClust(I_t[1], neg_bin_cf^2) |>
dist -> rand(dist, n_samples)

#For discrete distributions, checking mean and variance is as expected
Expand Down
Loading