Skip to content

Commit

Permalink
Fix for generated_quantities (#534)
Browse files Browse the repository at this point in the history
* added method for extracting the child lens from a varname subsumed by
another varname

* added nested_getindex and nested_setindex! for VarInfo

* added ConstructionBase.setproperties implementation for `Cholesky`

* fixed minor formatting issue

* added `supports_varname_indexing` for chains and use this in generated_quantities

* use a private method rather than overloading getindex for Chains

* removed getindex overloads in nested_index testing

* moved generated_quantities tests to test/model.jl

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* will now also correctly set variables to be resampled, etc.

* Update test/model.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added Compat as a test dep so we can use methods such as `stack`

* improved overload of ConstructionBase.setproperties

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added docstring to remove_parent_lens

* removed methods which are not useful for the purpose of this PR

* noticed we're incorrectly using chain rather than chain_params in generated_quantities

* Update ext/DynamicPPLMCMCChainsExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed doctests

* added Requires.jl

* Update src/DynamicPPL.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* bump patch version

* Update src/DynamicPPL.jl

Co-authored-by: David Widmann <[email protected]>

* moved new generated_quantities functionality into setval_and_resample!
so we can make use of this also for Turing.predict, etc.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update ext/DynamicPPLMCMCChainsExt.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/chains.jl

Co-authored-by: Xianda Sun <[email protected]>

* bump compat entry for ConstructionBase.jl

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <[email protected]>
Co-authored-by: Xianda Sun <[email protected]>
  • Loading branch information
4 people authored Sep 8, 2023
1 parent 52cd7f9 commit ffe9272
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 20 deletions.
26 changes: 14 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.16"
version = "0.23.17"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -16,32 +16,34 @@ LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[compat]
AbstractMCMC = "2, 3.0, 4"
AbstractPPL = "0.6"
BangBang = "0.3"
Bijectors = "0.13"
ChainRulesCore = "0.9.7, 0.10, 1"
ConstructionBase = "1"
ConstructionBase = "1.5.4"
Distributions = "0.23.8, 0.24, 0.25"
DocStringExtensions = "0.8, 0.9"
LogDensityProblems = "2"
MacroTools = "0.5.6"
MCMCChains = "6"
MacroTools = "0.5.6"
OrderedCollections = "1"
Requires = "1"
Setfield = "0.7.1, 0.8, 1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
DynamicPPLMCMCChainsExt = ["MCMCChains"]

[extras]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"

[weakdeps]
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
45 changes: 39 additions & 6 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
module DynamicPPLMCMCChainsExt

using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL
using MCMCChains: MCMCChains
else
using ..DynamicPPL: DynamicPPL
using ..MCMCChains: MCMCChains
end

# Return `true` when the chain `info` named tuple carries a `varname_to_symbol` field.
function _has_varname_to_symbol(info::NamedTuple{names}) where {names}
    return haskey(info, :varname_to_symbol)
end
# Guard used by the `VarName`-based accessors below.
#
# Returns `true` when `c` supports `VarName` indexing (as reported by
# `DynamicPPL.supports_varname_indexing`); otherwise throws an `ErrorException`.
function _check_varname_indexing(c::MCMCChains.Chains)
    # The message must not interpolate a variable name: no `VarName` is in scope here,
    # so interpolating one would raise `UndefVarError` instead of the intended error.
    return DynamicPPL.supports_varname_indexing(c) ||
           error("Chains do not support indexing using `VarName`s.")
end

# A few methods needed.
# A chain supports `VarName` indexing exactly when its `info` has a `varname_to_symbol` entry.
DynamicPPL.supports_varname_indexing(chain::MCMCChains.Chains) =
    _has_varname_to_symbol(chain.info)
# Look up `vn` at (`sample_idx`, `chain_idx`) by translating it to the chain's own
# parameter key through `c.info.varname_to_symbol`. Errors if `c` lacks that mapping.
function DynamicPPL.getindex_varname(
    c::MCMCChains.Chains, sample_idx, vn::DynamicPPL.VarName, chain_idx
)
    _check_varname_indexing(c)
    parameter_key = c.info.varname_to_symbol[vn]
    return c[sample_idx, parameter_key, chain_idx]
end
# Return the keys of `c.info.varname_to_symbol`, i.e. the `VarName`s recorded for `c`.
# Errors if `c` lacks that mapping.
function DynamicPPL.varnames(c::MCMCChains.Chains)
    _check_varname_indexing(c)
    varname_map = c.info.varname_to_symbol
    return keys(varname_map)
end

# Evaluate `model` once for every (sample, chain) pair of `chain_full` and return the
# model's return values, laid out like `Iterators.product(samples, chains)`.
#
# Only the `:parameters` section of `chain_full` is used to populate the `VarInfo`.
function DynamicPPL.generated_quantities(
    model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
    chain = MCMCChains.get_sections(chain_full, :parameters)
    varinfo = DynamicPPL.VarInfo(model)
    iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
    return map(iters) do (sample_idx, chain_idx)
        # Update the varinfo with the current sample and make variables not present in
        # `chain` to be sampled.
        DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)

        # TODO: Some of the variables can be a view into the `varinfo`, so we need to
        # `deepcopy` the `varinfo` before passing it to `model`.
        model(deepcopy(varinfo))
    end
end

Expand Down
13 changes: 13 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ const LEGACY_WARNING = """
# Necessary forward declarations
include("utils.jl")
include("selector.jl")
include("chains.jl")
include("model.jl")
include("sampler.jl")
include("varname.jl")
Expand All @@ -175,4 +176,16 @@ include("logdensityfunction.jl")
include("model_utils.jl")
include("extract_priors.jl")

if !isdefined(Base, :get_extension)
using Requires
end

# Fallback for Julia versions without native package extensions (no `Base.get_extension`):
# register a Requires.jl hook so that the MCMCChains integration file is included once
# MCMCChains (identified by the UUID below) is loaded.
@static if !isdefined(Base, :get_extension)
function __init__()
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
"../ext/DynamicPPLMCMCChainsExt.jl"
)
end
end

end # module
25 changes: 25 additions & 0 deletions src/chains.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""
supports_varname_indexing(chain::AbstractChains)
Return `true` if `chain` supports indexing using `VarName` in place of the
variable name index.
"""
supports_varname_indexing(::AbstractChains) = false

"""
getindex_varname(chain::AbstractChains, sample_idx, varname::VarName, chain_idx)
Return the value of `varname` in `chain` at `sample_idx` and `chain_idx`.
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function getindex_varname end

"""
varnames(chains::AbstractChains)
Return an iterator over the varnames present in `chains`.
Whether this method is implemented for `chains` is indicated by [`supports_varname_indexing`](@ref).
"""
function varnames end
36 changes: 36 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,42 @@ function splitlens(condition, lens)
return current_parent, current_child, condition(current_parent)
end

"""
remove_parent_lens(vn_parent::VarName, vn_child::VarName)
Remove the parent lens `vn_parent` from `vn_child`.
# Examples
```jldoctest
julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a))
(@lens _.a)
julia> DynamicPPL.remove_parent_lens(@varname(x), @varname(x.a[1]))
(@lens _.a[1])
julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1]))
(@lens _[1])
julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a[1].b))
(@lens _[1].b)
julia> DynamicPPL.remove_parent_lens(@varname(x.a), @varname(x.a))
ERROR: Could not find x.a in x.a
julia> DynamicPPL.remove_parent_lens(@varname(x.a[2]), @varname(x.a[1]))
ERROR: Could not find x.a[2] in x.a[1]
```
"""
function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) where {sym}
_, child, issuccess = splitlens(getlens(vn_child)) do lens
l = lens === nothing ? Setfield.IdentityLens() : lens
VarName(vn_child, l) == vn_parent
end

issuccess || error("Could not find $vn_parent in $vn_child")
return child
end

# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
Expand Down
59 changes: 57 additions & 2 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,41 @@ end
return Expr(:||, false, out...)
end

# Untyped `VarInfo`: fetch the metadata for `vn` and delegate to the shared
# implementation. Returns whatever `_nested_setindex_maybe!` returns.
function nested_setindex_maybe!(vi::UntypedVarInfo, val, vn::VarName)
    md = getmetadata(vi, vn)
    return _nested_setindex_maybe!(vi, md, val, vn)
end
# `VarInfo` backed by a `NamedTuple`: only symbols among the tuple's `names` can be set.
# Returns `nothing` when `vn`'s symbol is not one of them; otherwise delegates to
# `_nested_setindex_maybe!`.
function nested_setindex_maybe!(
    vi::VarInfo{<:NamedTuple{names}}, val, vn::VarName{sym}
) where {names,sym}
    sym in names || return nothing
    return _nested_setindex_maybe!(vi, getmetadata(vi, vn), val, vn)
end
# Set `val` for `vn` in `vi`, where `vn` may address only part of a stored variable.
#
# Returns the `VarName` whose value was written (`vn` itself, or the stored varname that
# subsumes it), or `nothing` if `md.vns` contains neither.
function _nested_setindex_maybe!(vi::VarInfo, md::Metadata, val, vn::VarName)
    stored_vns = md.vns

    # Exact match: the standard `setindex!` is sufficient.
    if vn in stored_vns
        setindex!(vi, val, vn)
        return vn
    end

    # Otherwise look for the first stored varname that subsumes `vn`.
    parent_idx = findfirst(candidate -> subsumes(candidate, vn), stored_vns)
    if parent_idx === nothing
        return nothing
    end
    vn_parent = stored_vns[parent_idx]

    # TODO: Ensure that we're working with a view here.
    val_parent = getindex(vi, vn_parent, getdist(md, vn_parent))

    # The tail lens locates `vn` inside the parent's value; write `val` there and store
    # the updated parent value back under the parent varname.
    tail_lens = remove_parent_lens(vn_parent, vn)
    setindex!(vi, set!!(val_parent, tail_lens, val), vn_parent)
    return vn_parent
end

# The default getindex & setindex!() for get & set values
# NOTE: vi[vn] will always transform the variable to its original space and Julia type
function getindex(vi::VarInfo, vn::VarName)
    dist = getdist(vi, vn)
    return getindex(vi, vn, dist)
end
Expand Down Expand Up @@ -1131,7 +1166,8 @@ The value(s) may or may not be transformed to Euclidean space.
"""
setindex!(vi::VarInfo, val, vn::VarName) = (setval!(vi, val, vn); return vi)
# `BangBang.setindex!!` for `VarInfo`: delegate to the mutating `setindex!` and return
# `vi` itself.
function BangBang.setindex!!(vi::VarInfo, val, vn::VarName)
    setindex!(vi, val, vn)
    return vi
end

"""
Expand Down Expand Up @@ -1600,7 +1636,26 @@ end
# Set the variables of `vi` from the draw at (`sample_idx`, `chain_idx`) of `chains`, and
# flag every variable that is not found in `chains` for resampling.
#
# When `chains` supports `VarName` indexing, values are written through
# `nested_setindex_maybe!`, so a chain entry may address only part of a variable in `vi`.
# Otherwise this falls back to the method taking raw values and `keys(chains)`.
function setval_and_resample!(
    vi::VarInfoOrThreadSafeVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int
)
    if !supports_varname_indexing(chains)
        # NOTE(review): the fallback's return value is defined outside this view; it is
        # passed through unchanged.
        return setval_and_resample!(
            vi, chains.value[sample_idx, :, chain_idx], keys(chains)
        )
    end

    # First we need to set every variable to be resampled.
    for vn in keys(vi)
        set_flag!(vi, vn, "del")
    end

    # Then we set the variables in `vi` from `chains`.
    for vn in varnames(chains)
        vn_updated = nested_setindex_maybe!(
            vi, getindex_varname(chains, sample_idx, vn, chain_idx), vn
        )

        # Unset the `del` flag if we found something.
        if vn_updated !== nothing
            # NOTE: This will be triggered even if only a subset of a variable has been set!
            unset_flag!(vi, vn_updated, "del")
        end
    end

    return vi
end

function _setval_and_resample_kernel!(
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
AbstractPPL = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Expand All @@ -24,6 +25,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
AbstractMCMC = "2.1, 3.0, 4"
AbstractPPL = "0.6"
Bijectors = "0.13"
Compat = "4.3.0"
Distributions = "0.25"
DistributionsAD = "0.6.3"
Documenter = "0.26.1, 0.27"
Expand Down
51 changes: 51 additions & 0 deletions test/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,4 +278,55 @@ end
@test DynamicPPL.TestUtils.posterior_mean(model) isa typeof(x)
end
end

# Checks that `generated_quantities` reproduces `LKJCholesky` draws when the chain stores
# the variable leaf-by-leaf and exposes a `varname_to_symbol` mapping.
@testset "generated_quantities on `LKJCholesky`" begin
    n = 10
    d = 2
    model = DynamicPPL.TestUtils.demo_lkjchol(d)
    xs = [model().x for _ in 1:n]

    # Extract varnames and values.
    # `∘` composes `collect` with the leaf extractor so each draw yields a vector of
    # (varname, value) pairs.
    vns_and_vals_xs = map(
        collect ∘ Base.Fix1(DynamicPPL.varname_and_value_leaves, @varname(x)), xs
    )
    vns = map(first, first(vns_and_vals_xs))
    vals = map(vns_and_vals_xs) do vns_and_vals
        map(last, vns_and_vals)
    end

    # Construct the chain.
    syms = map(Symbol, vns)
    vns_to_syms = OrderedDict{VarName,Any}(zip(vns, syms))

    chain = MCMCChains.Chains(
        permutedims(stack(vals)), syms; info=(varname_to_symbol=vns_to_syms,)
    )
    display(chain)

    # Test!
    results = generated_quantities(model, chain)
    for (x_true, result) in zip(xs, results)
        @test x_true.UL == result.x.UL
    end

    # With variables that aren't in the `model`.
    vns_to_syms_with_extra = let d = deepcopy(vns_to_syms)
        d[@varname(y)] = :y
        d
    end
    vals_with_extra = map(enumerate(vals)) do (i, v)
        vcat(v, i)
    end
    chain_with_extra = MCMCChains.Chains(
        permutedims(stack(vals_with_extra)),
        vcat(syms, [:y]);
        info=(varname_to_symbol=vns_to_syms_with_extra,),
    )
    display(chain_with_extra)
    # Test!
    results = generated_quantities(model, chain_with_extra)
    for (x_true, result) in zip(xs, results)
        @test x_true.UL == result.x.UL
    end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using MCMCChains
using Tracker
using Zygote
using Setfield
using Compat

using Distributed
using LinearAlgebra
Expand Down

2 comments on commit ffe9272

@yebai
Copy link
Member

@yebai yebai commented on ffe9272 Sep 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/91113

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.23.17 -m "<description of version>" ffe92722682761f0a953da2135716ba99ed1ac75
git push origin v0.23.17

Please sign in to comment.