Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New function spread_draws for creating tidybayes compliant MCMC output #57

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Samuel Abbott <[email protected]>", "Samuel Brand <[email protected]>", "Zacha
version = "0.1.0-DEV"

[deps]
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -26,4 +27,5 @@ Random = "1.9"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Turing = "0.30"
DataFramesMeta = "0.14"
julia = "1.9"
5 changes: 3 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ using Distributions,
ReverseDiff,
Optim,
Parameters,
QuadGK
QuadGK,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, default_rw_priors, default_delay_obs_priors
export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_draws

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections
Expand Down
21 changes: 21 additions & 0 deletions EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,24 @@ function generate_observation_kernel(delay_int, time_horizon)
end
return K
end

"""
spread_draws(chn::Chains)

Converts a `Chains` object into a DataFrame in `tidybayes` format.

# Arguments
- `chn::Chains`: The `Chains` object to be converted.

# Returns
- `df::DataFrame`: The converted DataFrame.
"""
function spread_draws(chn::Chains)
df = DataFrame(chn)
df = hcat(DataFrame(draw = 1:size(df, 1)), df)
@rename!(df, $(".draw") = :draw)
@rename!(df, $(".chain") = :chain)
@rename!(df, $(".iteration") = :iteration)

return df
end
2 changes: 2 additions & 0 deletions EpiAware/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
16 changes: 15 additions & 1 deletion EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ TestEnv.activate()

=#

# using TestEnv # Run in Test environment mode
## using TestEnv # Run in Test environment mode
# TestEnv.activate()

using EpiAware
Expand All @@ -67,6 +67,9 @@ using StatsPlots
using Random
using DynamicPPL
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this intensionally missing things from the test environmment? Like StatsPlots?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StatsPlots is in the test env but not a dependency of EpiAware; which I think is the right move.

using Statistics
using DataFramesMeta
using CSV # For outputting the MCMC chain

Random.seed!(0)

#=
Expand Down Expand Up @@ -179,6 +182,7 @@ scatter!(
xlabel = "Time",
ylabel = "Cases",
title = "Posterior Predictive Checking",
ylims = (-0.5, maximum(truth_data) * 2.5),
)

#=
Expand All @@ -196,4 +200,14 @@ scatter!(
xlabel = "Time",
ylabel = "Cases",
title = "Posterior Predictive Checking",
ylims = (-0.5, maximum(gen.I_t) * 1.5),
)

#=
## Outputing the MCMC chain
We can use `spread_draws` to convert the MCMC chain into a tidybayes format.
=#

df_chn = spread_draws(chn)
save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv")
CSV.write(save_path, df_chn)
18 changes: 18 additions & 0 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,21 @@ end
end

end
@testitem "Testing spread_draws function" begin
using DataFramesMeta, Turing

# Test case 1: Testing with non-empty Chains object
@testset "Test case 1" begin
X = rand(100, 2, 3)
chn = Chains(X, [:a, :b])
expected_df = DataFrame()
expected_df[!, ".draw"] = 1:300
expected_df[!, ".iteration"] = repeat(1:100, 3)
expected_df[!, ".chain"] = vcat(fill(1, 100), fill(2, 100), fill(3, 100))
expected_df.a = X[:, 1, :] |> vec
expected_df.b = X[:, 2, :] |> vec

df = spread_draws(chn)
@test df == expected_df
end
end
Loading