Move `predict` from Turing #716
base: master
@@ -42,6 +42,156 @@
    return keys(c.info.varname_to_symbol)
end

"""
    predict([rng::AbstractRNG,] model::Model, chain::MCMCChains.Chains; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`, and return the resulting `Chains`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.

# Examples
```jldoctest
julia> using DynamicPPL, AbstractMCMC, AdvancedHMC, Distributions, ForwardDiff, MCMCChains;

julia> @model function linear_reg(x, y, σ = 0.1)
           β ~ Normal(0, 1)
           for i ∈ eachindex(y)
               y[i] ~ Normal(β * x[i], σ)
           end
       end;

julia> σ = 0.1; f(x) = 2 * x + 0.1 * randn();

julia> Δ = 0.1; xs_train = 0:Δ:10; ys_train = f.(xs_train);

julia> xs_test = [10 + Δ, 10 + 2 * Δ]; ys_test = f.(xs_test);

julia> m_train = linear_reg(xs_train, ys_train, σ);

julia> n_train_logdensity_function = DynamicPPL.LogDensityFunction(m_train, DynamicPPL.VarInfo(m_train));

julia> chain_lin_reg = AbstractMCMC.sample(n_train_logdensity_function, NUTS(0.65), 200; chain_type=MCMCChains.Chains, param_names=[:β], discard_initial=100)
┌ Info: Found initial step size
└ ϵ = 0.003125

julia> m_test = linear_reg(xs_test, Vector{Union{Missing, Float64}}(undef, length(ys_test)), σ);

julia> predictions = predict(m_test, chain_lin_reg)
Object of type Chains, with data of type 100×2×1 Array{Float64,3}

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = y[1], y[2]

2-element Array{ChainDataFrame,1}

Summary Statistics
  parameters     mean     std  naive_se     mcse       ess   r_hat
  ──────────  ───────  ──────  ────────  ───────  ────────  ──────
        y[1]  20.1974  0.1007    0.0101  missing  101.0711  0.9922
        y[2]  20.3867  0.1062    0.0106  missing  101.4889  0.9903

Quantiles
  parameters     2.5%    25.0%    50.0%    75.0%    97.5%
  ──────────  ───────  ───────  ───────  ───────  ───────
        y[1]  20.0342  20.1188  20.2135  20.2588  20.4188
        y[2]  20.1870  20.3178  20.3839  20.4466  20.5895

julia> ys_pred = vec(mean(Array(group(predictions, :y)); dims = 1));

julia> sum(abs2, ys_test - ys_pred) ≤ 0.1
true
```
"""
function DynamicPPL.predict(
    rng::DynamicPPL.Random.AbstractRNG,
    model::DynamicPPL.Model,
    chain::MCMCChains.Chains;
    include_all=false,
)
    parameter_only_chain = MCMCChains.get_sections(chain, :parameters)
    prototypical_varinfo = DynamicPPL.VarInfo(model)

    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    predictive_samples = map(iters) do (sample_idx, chain_idx)
        varinfo = deepcopy(prototypical_varinfo)
        DynamicPPL.setval_and_resample!(
            varinfo, parameter_only_chain, sample_idx, chain_idx
        )
        model(rng, varinfo, DynamicPPL.SampleFromPrior())

        vals = DynamicPPL.values_as_in_model(model, varinfo)

Review thread on this line:
- This is actually changing the behavior from Turing.jl's implementation. This will result in also including variables used in […]
- Ooooh nice catch; thanks! Hmm, uncertain if this is desired behavior though 😕
- I saw your issue on […]. We would need to make a minor release of […].
- But isn't this the purpose of this PR? To move the […]? Whether we're using […]
- Ideally, I would want this PR to do a proper implementation of […]; what I was trying to say is that, with […]
- Improving it in a separate PR sounds good, but please create an issue to track @torfjelde's comment.

        varname_vals = mapreduce(
            collect,
            vcat,
            map(DynamicPPL.varname_and_value_leaves, keys(vals), values(vals)),
        )

        return (varname_and_values=varname_vals, logp=DynamicPPL.getlogp(varinfo))
    end

    chain_result = reduce(
        MCMCChains.chainscat,
        [
            _predictive_samples_to_chains(predictive_samples[:, chain_idx]) for
            chain_idx in 1:size(predictive_samples, 2)
        ],
    )
    parameter_names = if include_all
        MCMCChains.names(chain_result, :parameters)
    else
        filter(
            k -> !(k in MCMCChains.names(parameter_only_chain, :parameters)),
            names(chain_result, :parameters),
        )
    end
    return chain_result[parameter_names]
end

function _predictive_samples_to_arrays(predictive_samples)
    variable_names_set = DynamicPPL.OrderedCollections.OrderedSet{DynamicPPL.VarName}()

    sample_dicts = map(predictive_samples) do sample
        varname_value_pairs = sample.varname_and_values
        varnames = map(first, varname_value_pairs)
        values = map(last, varname_value_pairs)
        for varname in varnames
            push!(variable_names_set, varname)
        end

        return DynamicPPL.OrderedCollections.OrderedDict(zip(varnames, values))
    end

    variable_names = collect(variable_names_set)
    variable_values = [
        get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
        key in variable_names
    ]

    return variable_names, variable_values
end

function _predictive_samples_to_chains(predictive_samples)
    variable_names, variable_values = _predictive_samples_to_arrays(predictive_samples)
    variable_names_symbols = map(Symbol, variable_names)

    internal_parameters = [:lp]
    log_probabilities = reshape([sample.logp for sample in predictive_samples], :, 1)

    parameter_names = [variable_names_symbols; internal_parameters]
    parameter_values = hcat(variable_values, log_probabilities)
    parameter_values = MCMCChains.concretize(parameter_values)

    return MCMCChains.Chains(
        parameter_values, parameter_names, (internals=internal_parameters,)
    )
end

"""
    generated_quantities(model::Model, chain::MCMCChains.Chains)
@@ -1203,6 +1203,22 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
    end
end

"""
    predict([rng::AbstractRNG,] model::Model, chain; include_all=false)

Sample from the posterior predictive distribution by executing `model` with parameters fixed to each sample
in `chain`.

If `include_all` is `false`, the returned `Chains` will contain only those variables that were not fixed by
the samples in `chain`. This is useful when you want to sample only new variables from the posterior
predictive distribution.
"""
function predict(model::Model, chain; include_all=false)

Review thread on this line:
- In Turing.jl we're currently overloading […]
- Agree with this, but probably not time yet. Definitely after TuringLang/AbstractPPL.jl#81 is merged 👍
- But is this PR then held up until that PR is merged then?
- Also, that PR doesn't really matter; overloading […]
- Grey area: for me it is okay, because this PR is just about introducing a […]
- If nothing significant is missing in TuringLang/AbstractPPL.jl#81, let's merge it and overload […]

    # this is only defined in `ext/DynamicPPLMCMCChainsExt.jl`
    # TODO: add other methods for different types of `chain` arguments: e.g., `VarInfo`, `NamedTuple`, and `OrderedDict`
    return predict(Random.default_rng(), model, chain; include_all)
end

Review thread on lines +1216 to +1220:
- If so, we should definitively inform the user of this, no? Otherwise they'll just be like "oh why is this not defined?"
- I don't think we want to export […]; would […]
- If Turing exports it, it's better for DynamicPPL to export it, too.
- I agree, I was proposing delaying this until a good […]
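For orientation, a minimal sketch of the alternative being discussed here: extending a shared `predict` generic rather than defining a new function in DynamicPPL. Treating `StatsAPI.predict` as that shared generic is an assumption of this sketch (the exact function Turing currently overloads is elided in the comment above); this is not code from the PR.

```julia
# Hypothetical sketch only, assuming the shared generic is `StatsAPI.predict`.
import DynamicPPL, Random, StatsAPI

# DynamicPPL would add methods to the shared generic instead of owning `predict`.
# The rng-accepting method would still live in the MCMCChains extension, as in this PR.
function StatsAPI.predict(model::DynamicPPL.Model, chain; include_all=false)
    return StatsAPI.predict(Random.default_rng(), model, chain; include_all)
end
```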

"""
    generated_quantities(model::Model, parameters::NamedTuple)
    generated_quantities(model::Model, values, keys)
@@ -2,6 +2,7 @@
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"

Review thread on this line:
- Hmm, this doesn't quite seem worth it to test […]
- I didn't add anything or change the implementation in this PR. Agree AHMC is a heavy dep, but tests like https://github.com/TuringLang/DynamicPPL.jl/blob/fd1277b7201477448d3257cab65557b850bcf5b4/test/ext/DynamicPPLMCMCChainsExt.jl#L48C1-L55C45 […]
- Sure, but should just replace them with samples from the prior or something. This is just checking that the statistics are correct; it doesn't matter if these statistics are from the prior or posterior 🤷
- Actually, would it be really bad to make […]
- I can't look at this PR properly until Wednesday, but in https://github.com/TuringLang/DynamicPPL.jl/pull/733/files#diff-3981168ff1709b3f48c35e40f491c26d9b91fc29373e512f1272f3b928cea6c0 I wrote a function that generates a chain by sampling from the prior. (It's called […])
- @sunxd3, @penelopeysm the posterior of Bayesian linear regression can be obtained in closed form (i.e. it is a Gaussian, see here). I suggest that […]
- Though the closed-form posterior is a good idea, there's really no need to run this test on posterior samples :) These were just some stats that were picked to have something to compare to; prior chain is the way to go I think 👍
- Prior chain makes sense: should we generate samples from the prior, take out samples of a particular variable, and try to predict it?
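For reference, the closed-form posterior mentioned in the last few comments, written out for the test model `β ~ Normal(0, 1)`, `y[i] ~ Normal(β * x[i], σ)` with known `σ` (the standard conjugate-Gaussian result, stated here for orientation; it is not part of the PR):

```math
\beta \mid y \sim \mathcal{N}\left(\mu_n, s_n^2\right), \qquad
s_n^2 = \left(1 + \frac{1}{\sigma^2}\sum_i x_i^2\right)^{-1}, \qquad
\mu_n = \frac{s_n^2}{\sigma^2}\sum_i x_i y_i .
```

A test could compare `predict` output against draws from this Gaussian, or simply against a prior chain as the reviewers conclude, instead of running NUTS.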

Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"

@@ -32,6 +33,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.8.4, 0.9"
Accessors = "0.1"
Bijectors = "0.13.9, 0.14, 0.15"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
Combinatorics = "1"
Compat = "4.3.0"
Distributions = "0.25"
@@ -7,3 +7,170 @@
    @test size(chain_generated) == (1000, 1)
    @test mean(chain_generated) ≈ 0 atol = 0.1
end

@testset "predict" begin
    DynamicPPL.Random.seed!(100)

    @model function linear_reg(x, y, σ=0.1)
        β ~ Normal(0, 1)

        for i in eachindex(y)
            y[i] ~ Normal(β * x[i], σ)
        end
    end

    @model function linear_reg_vec(x, y, σ=0.1)
        β ~ Normal(0, 1)
        return y ~ MvNormal(β .* x, σ^2 * I)
    end

    f(x) = 2 * x + 0.1 * randn()

    Δ = 0.1
    xs_train = 0:Δ:10
    ys_train = f.(xs_train)
    xs_test = [10 + Δ, 10 + 2 * Δ]
    ys_test = f.(xs_test)

    # Infer
    m_lin_reg = linear_reg(xs_train, ys_train)
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg),
        AdvancedHMC.NUTS(0.65),

Review thread on this line:
- Really doesn't seem necessary to use […]
- Same reason as above: some tests rely on the quality of the samples.
        1000;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )

    # Predict on two last indices
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    ys_pred = vec(mean(Array(group(predictions, :y)); dims=1))

    # a test like this depends on the variance of the posterior;
    # this only makes sense if the posterior variance is about 0.002
    @test sum(abs2, ys_test - ys_pred) ≤ 0.1

    # Ensure that `rng` is respected
    predictions1 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    predictions2 = let rng = MersenneTwister(42)
        DynamicPPL.predict(rng, m_lin_reg_test, chain_lin_reg[1:2])
    end
    @test all(Array(predictions1) .== Array(predictions2))

    # Predict on two last indices for vectorized
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)
    ys_pred_vec = vec(mean(Array(group(predictions_vec, :y)); dims=1))

    @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1

    # Multiple chains
    chain_lin_reg = sample(
        DynamicPPL.LogDensityFunction(m_lin_reg, DynamicPPL.VarInfo(m_lin_reg)),
        AdvancedHMC.NUTS(0.65),
        MCMCThreads(),
        1000,
        2;
        chain_type=MCMCChains.Chains,
        param_names=[:β],
        discard_initial=100,
        n_adapt=100,
    )
    m_lin_reg_test = linear_reg(xs_test, fill(missing, length(ys_test)))
    predictions = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    @test size(chain_lin_reg, 3) == size(predictions, 3)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred = vec(mean(Array(group(predictions[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred) ≤ 0.1
    end

    # Predict on two last indices for vectorized
    m_lin_reg_test = linear_reg_vec(xs_test, missing)
    predictions_vec = DynamicPPL.predict(m_lin_reg_test, chain_lin_reg)

    for chain_idx in MCMCChains.chains(chain_lin_reg)
        ys_pred_vec = vec(mean(Array(group(predictions_vec[:, :, chain_idx], :y)); dims=1))
        @test sum(abs2, ys_test - ys_pred_vec) ≤ 0.1
    end

    # https://github.com/TuringLang/Turing.jl/issues/1352
    @model function simple_linear1(x, y)
        intercept ~ Normal(0, 1)
        coef ~ MvNormal(zeros(2), I)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear2(x, y)
        intercept ~ Normal(0, 1)
        coef ~ filldist(Normal(0, 1), 2)
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear3(x, y)
        intercept ~ Normal(0, 1)
        coef = Vector(undef, 2)
        for i in axes(coef, 1)
            coef[i] ~ Normal(0, 1)
        end
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    @model function simple_linear4(x, y)
        intercept ~ Normal(0, 1)
        coef1 ~ Normal(0, 1)
        coef2 ~ Normal(0, 1)
        coef = [coef1, coef2]
        coef = reshape(coef, 1, size(x, 1))

        mu = vec(intercept .+ coef * x)
        error ~ truncated(Normal(0, 1), 0, Inf)
        return y ~ MvNormal(mu, error^2 * I)
    end

    x = randn(2, 100)
    y = [1 + 2 * a + 3 * b for (a, b) in eachcol(x)]

    param_names = Dict(
        simple_linear1 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear2 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear3 => [:intercept, Symbol("coef[1]"), Symbol("coef[2]"), :error],
        simple_linear4 => [:intercept, :coef1, :coef2, :error],
    )
    @testset "$model" for model in
                          [simple_linear1, simple_linear2, simple_linear3, simple_linear4]
        m = model(x, y)
        chain = sample(
            DynamicPPL.LogDensityFunction(m),
            AdvancedHMC.NUTS(0.65),
            400;
            initial_params=rand(4),
            chain_type=MCMCChains.Chains,
            param_names=param_names[model],
            discard_initial=100,
            n_adapt=100,
        )
        chain_predict = DynamicPPL.predict(model(x, missing), chain)
        mean_prediction = [mean(chain_predict["y[$i]"].data) for i in 1:length(y)]
        @test mean(abs2, mean_prediction - y) ≤ 1e-3
    end
end

Review thread:
- Same here: no need to use `AdvancedHMC` (or any of the other packages), just construct the `Chains` by hand. This also doesn't actually show that you need to import `MCMCChains` for this to work, which might be a good idea.
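A minimal sketch of the kind of test the reviewer describes (not code from this PR): the chain is constructed by hand from prior draws, which makes the `MCMCChains` import explicit and avoids the `AdvancedHMC` dependency. The chain length and test inputs below are arbitrary choices.

```julia
using DynamicPPL, Distributions, MCMCChains
# `MCMCChains` must be loaded for the `predict(::Model, ::Chains)` method added in this PR,
# since it lives in the DynamicPPLMCMCChainsExt extension.

# Same model as in the tests above.
@model function linear_reg(x, y, σ=0.1)
    β ~ Normal(0, 1)
    for i in eachindex(y)
        y[i] ~ Normal(β * x[i], σ)
    end
end

# Build the chain by hand from draws of the prior β ~ Normal(0, 1); no sampler needed.
β_draws = rand(Normal(0, 1), 1_000)
chain = MCMCChains.Chains(reshape(β_draws, :, 1, 1), [:β])

# Predict `y` at new inputs; the statistics of `predictions` can then be checked
# against the known prior-predictive distribution instead of a posterior.
m_test = linear_reg([0.1, 0.2], fill(missing, 2))
predictions = DynamicPPL.predict(m_test, chain)
```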