Skip to content

Commit

Permalink
Fix for LogDensityFunction (#621)
Browse files Browse the repository at this point in the history
* lazily resolve context to avoid overriding the model context

* bump patch version

* Update src/logdensityfunction.jl

* Update src/logdensityfunction.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* replaces more references to `f.context` with `getcontext(f)`

* Bump version to v0.28

Co-authored-by: Hong Ge <[email protected]>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Markus Hauru <[email protected]>
Co-authored-by: Hong Ge <[email protected]>
  • Loading branch information
4 people authored Jun 25, 2024
1 parent d384da2 commit 2b97177
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.27.2"
version = "0.28"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
20 changes: 13 additions & 7 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ struct LogDensityFunction{V,M,C}
varinfo::V
"model used for evaluation"
model::M
"context used for evaluation"
"context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable"
context::C
end

Expand All @@ -66,15 +66,20 @@ end
function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model),
context::AbstractContext=model.context,
context::Union{Nothing,AbstractContext}=nothing,
)
return LogDensityFunction(varinfo, model, context)
end

# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`.
function getcontext(f::LogDensityFunction)
return f.context === nothing ? leafcontext(f.model.context) : f.context
end

# HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time
# we need to define these annoying methods to ensure that we stay compatible with everything.
getsampler(f::LogDensityFunction) = getsampler(f.context)
hassampler(f::LogDensityFunction) = hassampler(f.context)
getsampler(f::LogDensityFunction) = getsampler(getcontext(f))
hassampler(f::LogDensityFunction) = hassampler(getcontext(f))

_get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx)
_get_indexer(ctx::SamplingContext) = ctx.sampler
Expand All @@ -86,12 +91,13 @@ _get_indexer(::IsLeaf, ctx::AbstractContext) = Colon()
Return the parameters of the wrapped varinfo as a vector.
"""
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)]
getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))]

# LogDensityProblems interface
function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector)
vi_new = unflatten(f.varinfo, f.context, θ)
return getlogp(last(evaluate!!(f.model, vi_new, f.context)))
context = getcontext(f)
vi_new = unflatten(f.varinfo, context, θ)
return getlogp(last(evaluate!!(f.model, vi_new, context)))
end
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction})
return LogDensityProblems.LogDensityOrder{0}()
Expand Down

2 comments on commit 2b97177

@yebai
Copy link
Member

@yebai yebai commented on 2b97177 Jun 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/109771

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.28.0 -m "<description of version>" 2b971773461711349cd7ce476bfa6a73e3093819
git push origin v0.28.0

Please sign in to comment.