Skip to content

Commit

Permalink
Merge pull request #57 from CDCgov:55-pipe-output-to-a-tidybayes-comp…
Browse files Browse the repository at this point in the history
…liant-output

New function `spread_draws` for creating `tidybayes` compliant MCMC output
  • Loading branch information
seabbs authored Feb 21, 2024
2 parents 184f7e9 + 048fd0a commit 2e18d00
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 3 deletions.
2 changes: 2 additions & 0 deletions EpiAware/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Samuel Abbott <[email protected]>", "Samuel Brand <[email protected]>", "Zacha
version = "0.1.0-DEV"

[deps]
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Expand All @@ -26,4 +27,5 @@ Random = "1.9"
ReverseDiff = "1.15"
SparseArrays = "1.10"
Turing = "0.30"
DataFramesMeta = "0.14"
julia = "1.9"
5 changes: 3 additions & 2 deletions EpiAware/src/EpiAware.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,11 @@ using Distributions,
ReverseDiff,
Optim,
Parameters,
QuadGK
QuadGK,
DataFramesMeta

# Exported utilities
export create_discrete_pmf, default_rw_priors, default_delay_obs_priors
export create_discrete_pmf, default_rw_priors, default_delay_obs_priors, spread_draws

# Exported types
export EpiData, Renewal, ExpGrowthRate, DirectInfections
Expand Down
21 changes: 21 additions & 0 deletions EpiAware/src/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,24 @@ function generate_observation_kernel(delay_int, time_horizon)
end
return K
end

"""
spread_draws(chn::Chains)
Converts a `Chains` object into a DataFrame in `tidybayes` format.
# Arguments
- `chn::Chains`: The `Chains` object to be converted.
# Returns
- `df::DataFrame`: The converted DataFrame.
"""
function spread_draws(chn::Chains)
df = DataFrame(chn)
df = hcat(DataFrame(draw = 1:size(df, 1)), df)
@rename!(df, $(".draw") = :draw)
@rename!(df, $(".chain") = :chain)
@rename!(df, $(".iteration") = :iteration)

return df
end
2 changes: 2 additions & 0 deletions EpiAware/test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
[deps]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFramesMeta = "1313f7d8-7da2-5740-9ea0-a2ca25f37964"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
16 changes: 15 additions & 1 deletion EpiAware/test/predictive_checking/toy_model_log_infs_RW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ TestEnv.activate()
=#

# using TestEnv # Run in Test environment mode
## using TestEnv # Run in Test environment mode
# TestEnv.activate()

using EpiAware
Expand All @@ -67,6 +67,9 @@ using StatsPlots
using Random
using DynamicPPL
using Statistics
using DataFramesMeta
using CSV # For outputting the MCMC chain

Random.seed!(0)

#=
Expand Down Expand Up @@ -179,6 +182,7 @@ scatter!(
xlabel = "Time",
ylabel = "Cases",
title = "Posterior Predictive Checking",
ylims = (-0.5, maximum(truth_data) * 2.5),
)

#=
Expand All @@ -196,4 +200,14 @@ scatter!(
xlabel = "Time",
ylabel = "Cases",
title = "Posterior Predictive Checking",
ylims = (-0.5, maximum(gen.I_t) * 1.5),
)

#=
## Outputing the MCMC chain
We can use `spread_draws` to convert the MCMC chain into a tidybayes format.
=#

df_chn = spread_draws(chn)
save_path = joinpath(@__DIR__, "assets/toy_model_log_infs_RW_draws.csv")
CSV.write(save_path, df_chn)
18 changes: 18 additions & 0 deletions EpiAware/test/test_utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,21 @@ end
end

end
@testitem "Testing spread_draws function" begin
using DataFramesMeta, Turing

# Test case 1: Testing with non-empty Chains object
@testset "Test case 1" begin
X = rand(100, 2, 3)
chn = Chains(X, [:a, :b])
expected_df = DataFrame()
expected_df[!, ".draw"] = 1:300
expected_df[!, ".iteration"] = repeat(1:100, 3)
expected_df[!, ".chain"] = vcat(fill(1, 100), fill(2, 100), fill(3, 100))
expected_df.a = X[:, 1, :] |> vec
expected_df.b = X[:, 2, :] |> vec

df = spread_draws(chn)
@test df == expected_df
end
end

0 comments on commit 2e18d00

Please sign in to comment.