From 2b971773461711349cd7ce476bfa6a73e3093819 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Tue, 25 Jun 2024 17:53:23 +0100 Subject: [PATCH] Fix for `LogDensityFunction` (#621) * lazily resolve context to avoid overriding the model context * bump patch version * Update src/logdensityfunction.jl * Update src/logdensityfunction.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * replaces more references to `f.context` with `getcontext(f)` * Bump version to v0.28 Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Markus Hauru Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> --- Project.toml | 2 +- src/logdensityfunction.jl | 20 +++++++++++++------- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index a81112d74..78acb2566 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "DynamicPPL" uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8" -version = "0.27.2" +version = "0.28" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index 007dfef11..8935edc12 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -49,7 +49,7 @@ struct LogDensityFunction{V,M,C} varinfo::V "model used for evaluation" model::M - "context used for evaluation" + "context used for evaluation; if `nothing`, `leafcontext(model.context)` will be used when applicable" context::C end @@ -66,15 +66,20 @@ end function LogDensityFunction( model::Model, varinfo::AbstractVarInfo=VarInfo(model), - context::AbstractContext=model.context, + context::Union{Nothing,AbstractContext}=nothing, ) return LogDensityFunction(varinfo, model, context) end +# If a `context` has been specified, we use that. Otherwise we just use the leaf context of `model`. +function getcontext(f::LogDensityFunction) + return f.context === nothing ? leafcontext(f.model.context) : f.context +end + # HACK: heavy usage of `AbstractSampler` for, well, _everything_, is being phased out. In the mean time # we need to define these annoying methods to ensure that we stay compatible with everything. -getsampler(f::LogDensityFunction) = getsampler(f.context) -hassampler(f::LogDensityFunction) = hassampler(f.context) +getsampler(f::LogDensityFunction) = getsampler(getcontext(f)) +hassampler(f::LogDensityFunction) = hassampler(getcontext(f)) _get_indexer(ctx::AbstractContext) = _get_indexer(NodeTrait(ctx), ctx) _get_indexer(ctx::SamplingContext) = ctx.sampler @@ -86,12 +91,13 @@ _get_indexer(::IsLeaf, ctx::AbstractContext) = Colon() Return the parameters of the wrapped varinfo as a vector. """ -getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(f.context)] +getparams(f::LogDensityFunction) = f.varinfo[_get_indexer(getcontext(f))] # LogDensityProblems interface function LogDensityProblems.logdensity(f::LogDensityFunction, θ::AbstractVector) - vi_new = unflatten(f.varinfo, f.context, θ) - return getlogp(last(evaluate!!(f.model, vi_new, f.context))) + context = getcontext(f) + vi_new = unflatten(f.varinfo, context, θ) + return getlogp(last(evaluate!!(f.model, vi_new, context))) end function LogDensityProblems.capabilities(::Type{<:LogDensityFunction}) return LogDensityProblems.LogDensityOrder{0}()