diff --git a/Project.toml b/Project.toml index ae402a8aa..754900be2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.23.16" +version = "0.23.17" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" @@ -16,32 +16,34 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" -[weakdeps] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - -[extensions] -DynamicPPLMCMCChainsExt = ["MCMCChains"] - -[extras] -MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" - [compat] AbstractMCMC = "2, 3.0, 4" AbstractPPL = "0.6" BangBang = "0.3" Bijectors = "0.13" ChainRulesCore = "0.9.7, 0.10, 1" -ConstructionBase = "1" +ConstructionBase = "1.5.4" Distributions = "0.23.8, 0.24, 0.25" DocStringExtensions = "0.8, 0.9" LogDensityProblems = "2" -MacroTools = "0.5.6" MCMCChains = "6" +MacroTools = "0.5.6" OrderedCollections = "1" +Requires = "1" Setfield = "0.7.1, 0.8, 1" ZygoteRules = "0.2" julia = "1.6" + +[extensions] +DynamicPPLMCMCChainsExt = ["MCMCChains"] + +[extras] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" + +[weakdeps] +MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" diff --git a/ext/DynamicPPLMCMCChainsExt.jl b/ext/DynamicPPLMCMCChainsExt.jl index de77f58f2..2630e9d1b 100644 --- a/ext/DynamicPPLMCMCChainsExt.jl +++ b/ext/DynamicPPLMCMCChainsExt.jl @@ -1,15 +1,48 @@ module DynamicPPLMCMCChainsExt -using DynamicPPL: DynamicPPL -using MCMCChains: MCMCChains +if isdefined(Base, :get_extension) + using DynamicPPL: DynamicPPL + using MCMCChains: MCMCChains +else + using ..DynamicPPL: DynamicPPL + using ..MCMCChains: MCMCChains +end + +_has_varname_to_symbol(info::NamedTuple{names}) where {names} = :varname_to_symbol in names +function _check_varname_indexing(c::MCMCChains.Chains) + return DynamicPPL.supports_varname_indexing(c) || + error("Chains do not support indexing using $vn.") +end + +# A few methods needed. +function DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) + return _has_varname_to_symbol(chain.info) +end +function DynamicPPL.getindex_varname( + c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx +) + _check_varname_indexing(c) + return c[sample_idx, c.info.varname_to_symbol[vn], chain_idx] +end +function DynamicPPL.varnames(c::MCMCChains.Chains) + _check_varname_indexing(c) + return keys(c.info.varname_to_symbol) +end -function DynamicPPL.generated_quantities(model::DynamicPPL.Model, chain::MCMCChains.Chains) - chain_parameters = MCMCChains.get_sections(chain, :parameters) +function DynamicPPL.generated_quantities( + model::DynamicPPL.Model, chain_full::MCMCChains.Chains +) + chain = MCMCChains.get_sections(chain_full, :parameters) varinfo = DynamicPPL.VarInfo(model) iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3)) return map(iters) do (sample_idx, chain_idx) - DynamicPPL.setval_and_resample!(varinfo, chain_parameters, sample_idx, chain_idx) - model(varinfo) + # Update the varinfo with the current sample and make variables not present in `chain` + # to be sampled. + DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx) + + # TODO: Some of the variables can be a view into the `varinfo`, so we need to + # `deepcopy` the `varinfo` before passing it to `model`. + model(deepcopy(varinfo)) end end diff --git a/src/DynamicPPL.jl b/src/DynamicPPL.jl index 1fd008ffe..58e357f1f 100644 --- a/src/DynamicPPL.jl +++ b/src/DynamicPPL.jl @@ -154,6 +154,7 @@ const LEGACY_WARNING = """ # Necessary forward declarations include("utils.jl") include("selector.jl") +include("chains.jl") include("model.jl") include("sampler.jl") include("varname.jl") @@ -175,4 +176,16 @@ include("logdensityfunction.jl") include("model_utils.jl") include("extract_priors.jl") +if !isdefined(Base, :get_extension) + using Requires +end + +@static if !isdefined(Base, :get_extension) + function __init__() + @require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include( + "../ext/DynamicPPLMCMCChainsExt.jl" + ) + end +end + end # module diff --git a/src/chains.jl b/src/chains.jl new file mode 100644 index 000000000..fd6564e5b --- /dev/null +++ b/src/chains.jl @@ -0,0 +1,25 @@ +""" + supports_varname_indexing(chain::AbstractChains) + +Return `true` if `chain` supports indexing using `VarName` in place of the +variable name index. +""" +supports_varname_indexing(::AbstractChains) = false + +""" + getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx) + +Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`. + +Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). +""" +function getindex_varname end + +""" + varnames(chains::AbstractChains) + +Return an iterator over the varnames present in `chains`. + +Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref). +""" +function varnames end diff --git a/src/utils.jl b/src/utils.jl index 294bc75ce..0135e4c24 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -501,6 +501,42 @@ function splitlens(condition, lens) return current_parent, current_child, condition(current_parent) end +""" + remove_parent_lens(vn_parent::VarName, vn_child::VarName) + +Remove the parent lens `vn_parent` from `vn_child`. + +# Examples +```jldoctest +julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a)) +(@lens _.a) + +julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a[1])) +(@lens _.a[1]) + +julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1])) +(@lens _[1]) + +julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1].b)) +(@lens _[1].b) + +julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a)) +ERROR: Could not find x.a in x.a + +julia> DynamicPPL.remove_parent_lens(@varname(x.a[2]), @varname(x.a[1])) +ERROR: Could not find x.a[2] in x.a[1] +``` +""" +function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym} + _, child, issuccess = splitlens(getlens(vn_child)) do lens + l = lens === nothing ? Setfield.IdentityLens() : lens + VarName(vn_child, l) == vn_parent + end + + issuccess || error("Could not find $vn_parent in $vn_child") + return child +end + # HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233 # and https://github.com/JuliaFolds/BangBang.jl/pull/238. # HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`. diff --git a/src/varinfo.jl b/src/varinfo.jl index fbe3f6088..ddb4caffb 100644 --- a/src/varinfo.jl +++ b/src/varinfo.jl @@ -1064,6 +1064,41 @@ end return Expr(:||, false, out...) end +function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName) + return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) +end +function nested_setindex_maybe!( + vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym} +) where {names,sym} + return if sym in names + _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn) + else + nothing + end +end +function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName) + # If `vn` is in `vns`, then we can just use the standard `setindex!`. + vns = md.vns + if vn in vns + setindex!(vi, val, vn) + return vn + end + + # Otherwise, we need to check if either of the `vns` subsumes `vn`. + i = findfirst(Base.Fix2(subsumes, vn), vns) + i === nothing && return nothing + + vn_parent = vns[i] + dist = getdist(md, vn_parent) + val_parent = getindex(vi, vn_parent, dist) # TODO: Ensure that we're working with a view here. + # Split the varname into its tail lens. + lens = remove_parent_lens(vn_parent, vn) + # Update the value for the parent. + val_parent_updated = set!!(val_parent, lens, val) + setindex!(vi, val_parent_updated, vn_parent) + return vn_parent +end + # The default getindex & setindex!() for get & set values # NOTE: vi[vn] will always transform the variable to its original space and Julia type getindex(vi::VarInfo, vn::VarName) = getindex(vi, vn, getdist(vi, vn)) @@ -1131,7 +1166,8 @@ The value(s) may or may not be transformed to Euclidean space. """ setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi) function BangBang.setindex!!(vi::VarInfo, val, vn::VarName) - return (setindex!(vi, val, vn); return vi) + setindex!(vi, val, vn) + return vi end """ @@ -1600,7 +1636,26 @@ end function setval_and_resample!( vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int ) - return setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) + if supports_varname_indexing(chains) + # First we need to set every variable to be resampled. + for vn in keys(vi) + set_flag!(vi, vn, "del") + end + # Then we set the variables in `varinfo` from `chain`. + for vn in varnames(chains) + vn_updated = nested_setindex_maybe!( + vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn + ) + + # Unset the `del` flag if we found something. + if vn_updated !== nothing + # NOTE: This will be triggered even if only a subset of a variable has been set! + unset_flag!(vi, vn_updated, "del") + end + end + else + setval_and_resample!(vi, chains.value[sample_idx, :, chain_idx], keys(chains)) + end end function _setval_and_resample_kernel!( diff --git a/test/Project.toml b/test/Project.toml index b36a7e23a..ade5ade1a 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -2,6 +2,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" @@ -24,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" AbstractMCMC = "2.1, 3.0, 4" AbstractPPL = "0.6" Bijectors = "0.13" +Compat = "4.3.0" Distributions = "0.25" DistributionsAD = "0.6.3" Documenter = "0.26.1, 0.27" diff --git a/test/model.jl b/test/model.jl index 481aa4e38..fa7f5de47 100644 --- a/test/model.jl +++ b/test/model.jl @@ -278,4 +278,55 @@ end @test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x) end end + + @testset "generated_quantities on `LKJCholesky`" begin + n = 10 + d = 2 + model = DynamicPPL.TestUtils.demo_lkjchol(d) + xs = [model().x for _ in 1:n] + + # Extract varnames and values. + vns_and_vals_xs = map( + collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs + ) + vns = map(first, first(vns_and_vals_xs)) + vals = map(vns_and_vals_xs) do vns_and_vals + map(last, vns_and_vals) + end + + # Construct the chain. + syms = map(Symbol, vns) + vns_to_syms = OrderedDict{VarName,Any}(zip(vns, syms)) + + chain = MCMCChains.Chains( + permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,) + ) + display(chain) + + # Test! + results = generated_quantities(model, chain) + for (x_true, result) in zip(xs, results) + @test x_true.UL == result.x.UL + end + + # With variables that aren't in the `model`. + vns_to_syms_with_extra = let d = deepcopy(vns_to_syms) + d[@varname(y)] = :y + d + end + vals_with_extra = map(enumerate(vals)) do (i, v) + vcat(v, i) + end + chain_with_extra = MCMCChains.Chains( + permutedims(stack(vals_with_extra)), + vcat(syms, [:y]); + info=(varname_to_symbol=vns_to_syms_with_extra,), + ) + display(chain_with_extra) + # Test! + results = generated_quantities(model, chain_with_extra) + for (x_true, result) in zip(xs, results) + @test x_true.UL == result.x.UL + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 74cabc272..43d68386c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -11,6 +11,7 @@ using MCMCChains using Tracker using Zygote using Setfield +using Compat using Distributed using LinearAlgebra